use crate::audit::{AuditConfig, AuditLogger};
use axum::{
extract::Request,
http::{header, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
Json,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Lock-free authentication-failure counters, broken down by reason.
///
/// All counters are updated and read with `Ordering::Relaxed`, so no ordering
/// between different counters is implied; each one is individually exact.
#[derive(Debug, Default)]
pub struct AuthMetrics {
    /// Sum of all recorded failures, whatever the reason.
    pub failures_total: AtomicU64,
    /// Failures recorded as [`AuthFailureReason::MissingKey`].
    pub failures_missing_key: AtomicU64,
    /// Failures recorded as [`AuthFailureReason::InvalidKey`].
    pub failures_invalid_key: AtomicU64,
    /// Failures recorded as [`AuthFailureReason::ExpiredKey`].
    pub failures_expired_key: AtomicU64,
    /// Failures recorded as [`AuthFailureReason::DisabledKey`].
    pub failures_disabled_key: AtomicU64,
    /// Failures recorded as [`AuthFailureReason::InsufficientScope`].
    pub failures_insufficient_scope: AtomicU64,
}

impl AuthMetrics {
    /// Creates a metrics set with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the per-reason counter that tracks `reason`.
    fn counter_for(&self, reason: AuthFailureReason) -> &AtomicU64 {
        match reason {
            AuthFailureReason::MissingKey => &self.failures_missing_key,
            AuthFailureReason::InvalidKey => &self.failures_invalid_key,
            AuthFailureReason::ExpiredKey => &self.failures_expired_key,
            AuthFailureReason::DisabledKey => &self.failures_disabled_key,
            AuthFailureReason::InsufficientScope => &self.failures_insufficient_scope,
        }
    }

    /// Records one failure: bumps the running total, then the counter for `reason`.
    pub fn record_failure(&self, reason: AuthFailureReason) {
        self.failures_total.fetch_add(1, Ordering::Relaxed);
        self.counter_for(reason).fetch_add(1, Ordering::Relaxed);
    }

    /// Returns how many failures have been recorded across all reasons.
    pub fn total_failures(&self) -> u64 {
        self.failures_total.load(Ordering::Relaxed)
    }

    /// Renders the per-reason counters in the Prometheus text exposition format.
    ///
    /// Every reason becomes one sample of the `infernum_auth_failures_total`
    /// counter, distinguished by a `reason` label. The aggregate total is not
    /// emitted as a separate sample.
    pub fn render_prometheus(&self) -> String {
        let samples = [
            ("missing_key", &self.failures_missing_key),
            ("invalid_key", &self.failures_invalid_key),
            ("expired_key", &self.failures_expired_key),
            ("disabled_key", &self.failures_disabled_key),
            ("insufficient_scope", &self.failures_insufficient_scope),
        ];
        let mut text =
            String::from("# HELP infernum_auth_failures_total Total authentication failures.\n");
        text.push_str("# TYPE infernum_auth_failures_total counter\n");
        for (reason, counter) in samples {
            text.push_str(&format!(
                "infernum_auth_failures_total{{reason=\"{}\"}} {}\n",
                reason,
                counter.load(Ordering::Relaxed)
            ));
        }
        text
    }
}

/// Why an authentication attempt was rejected.
///
/// Selects which counter [`AuthMetrics::record_failure`] increments.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthFailureReason {
    MissingKey,
    InvalidKey,
    ExpiredKey,
    DisabledKey,
    InsufficientScope,
}
/// Coarse privilege level attached to an [`ApiKey`].
///
/// Defaults to [`Permission::User`], the less-privileged level. The default is
/// derived via `#[default]` rather than hand-written (clippy `derivable_impls`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub enum Permission {
    /// Regular access; rejected by the middleware on admin-only paths.
    #[default]
    User,
    /// Full access, including admin-only paths.
    Admin,
}
/// Fine-grained capability an [`ApiKey`] may carry.
///
/// [`required_scope_for_path`] maps request paths onto these scopes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Scope {
    /// Inference endpoints: `/v1/...` and the non-admin `/api/...` paths.
    Inference,
    /// Administrative endpoints such as model load/unload, keys and config.
    Admin,
    /// The `/metrics` endpoint.
    Metrics,
}
impl Scope {
pub fn as_str(&self) -> &'static str {
match self {
Self::Inference => "inference",
Self::Admin => "admin",
Self::Metrics => "metrics",
}
}
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"inference" | "inf" => Some(Self::Inference),
"admin" | "adm" => Some(Self::Admin),
"metrics" | "met" => Some(Self::Metrics),
_ => None,
}
}
}
/// Returns the scope a request path requires, or `None` when no scope is needed.
///
/// `/health` and `/ready` (exact match) never need a scope. Otherwise the first
/// matching group wins, in this order:
/// 1. admin endpoints,
/// 2. explicit inference endpoints,
/// 3. `/metrics`,
/// 4. any other `/v1/` or `/api/` path (inference).
///
/// Matching is a plain string-prefix test with no path-segment boundary check.
pub fn required_scope_for_path(path: &str) -> Option<Scope> {
    const ADMIN_PREFIXES: &[&str] = &[
        "/api/models/load",
        "/api/models/unload",
        "/api/keys",
        "/api/config",
        "/admin/models",
    ];
    const INFERENCE_PREFIXES: &[&str] = &[
        "/v1/chat",
        "/v1/completions",
        "/v1/embeddings",
        "/v1/models",
    ];
    // Catch-all for API paths not listed above.
    const FALLBACK_INFERENCE_PREFIXES: &[&str] = &["/v1/", "/api/"];

    let has_any_prefix =
        |prefixes: &[&str]| prefixes.iter().any(|prefix| path.starts_with(prefix));

    if matches!(path, "/health" | "/ready") {
        None
    } else if has_any_prefix(ADMIN_PREFIXES) {
        Some(Scope::Admin)
    } else if has_any_prefix(INFERENCE_PREFIXES) {
        Some(Scope::Inference)
    } else if path.starts_with("/metrics") {
        Some(Scope::Metrics)
    } else if has_any_prefix(FALLBACK_INFERENCE_PREFIXES) {
        Some(Scope::Inference)
    } else {
        None
    }
}
/// A configured API key together with its access metadata.
///
/// NOTE(review): `key` has no serde skip attribute, so serializing an unhashed
/// key writes the plaintext secret. Call [`ApiKey::hashed`] before persisting.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ApiKey {
    /// Key material presented by clients.
    ///
    /// Becomes the placeholder `"[HASHED]"` after [`ApiKey::hashed`].
    pub key: String,
    /// `sha256:<hex>` digest set by [`ApiKey::hashed`]; omitted when serializing if absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub key_hash: Option<String>,
    /// Coarse privilege level, checked by the middleware on admin-only paths.
    pub permission: Permission,
    /// Fine-grained scopes; deserializes to an empty list when the field is missing.
    #[serde(default)]
    pub scopes: Vec<Scope>,
    /// Optional human-readable label.
    pub name: Option<String>,
    /// Disabled keys are rejected during validation.
    pub enabled: bool,
    /// Expiry instant (UTC); `None` means the key never expires.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Creation instant; defaults to "now" when missing during deserialization.
    #[serde(default = "chrono::Utc::now")]
    pub created_at: chrono::DateTime<chrono::Utc>,
}
impl ApiKey {
    /// Creates an enabled, non-expiring key with [`Permission::User`] and only
    /// the [`Scope::Inference`] scope.
    pub fn user(key: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            key_hash: None,
            permission: Permission::User,
            scopes: vec![Scope::Inference],
            name: None,
            enabled: true,
            expires_at: None,
            created_at: chrono::Utc::now(),
        }
    }

    /// Creates an enabled, non-expiring key with [`Permission::Admin`] and every scope.
    pub fn admin(key: impl Into<String>) -> Self {
        Self {
            key: key.into(),
            key_hash: None,
            permission: Permission::Admin,
            scopes: vec![Scope::Inference, Scope::Admin, Scope::Metrics],
            name: None,
            enabled: true,
            expires_at: None,
            created_at: chrono::Utc::now(),
        }
    }

    /// Creates an enabled, non-expiring key carrying exactly `scopes`.
    ///
    /// The permission level is derived from the scopes:
    /// - [`Permission::Admin`] when `scopes` contains [`Scope::Admin`],
    /// - [`Permission::User`] otherwise.
    pub fn with_scopes(key: impl Into<String>, scopes: Vec<Scope>) -> Self {
        let permission = if scopes.contains(&Scope::Admin) {
            Permission::Admin
        } else {
            Permission::User
        };
        Self {
            key: key.into(),
            key_hash: None,
            permission,
            scopes,
            name: None,
            enabled: true,
            expires_at: None,
            created_at: chrono::Utc::now(),
        }
    }

    /// Sets an absolute expiry instant.
    pub fn with_expiration(mut self, expires_at: chrono::DateTime<chrono::Utc>) -> Self {
        self.expires_at = Some(expires_at);
        self
    }

    /// Sets the expiry from a relative duration or an absolute RFC 3339 timestamp.
    ///
    /// Relative durations look like `"30d"`, `"12h"` or `"15m"`.
    ///
    /// An unrecognised or unrepresentable value clears the expiry, so the key
    /// never expires. A warning is logged so that such a misconfiguration
    /// (e.g. a typo like `"30x"`) is visible rather than silent.
    pub fn with_expiration_duration(mut self, duration: &str) -> Self {
        self.expires_at = Self::parse_duration_to_expiry(duration);
        if self.expires_at.is_none() {
            tracing::warn!(
                "Unrecognized key expiration '{}'; key will not expire",
                duration
            );
        }
        self
    }

    /// Converts a relative duration or an RFC 3339 timestamp into an absolute expiry.
    ///
    /// A relative duration is a signed integer followed by `d` (days), `h`
    /// (hours) or `m` (minutes), measured from now.
    ///
    /// Returns `None` in two cases:
    /// - the input matches neither form;
    /// - the resulting instant is out of range. Previously such values panicked
    ///   inside `chrono::Duration::days` or `DateTime + Duration`.
    fn parse_duration_to_expiry(duration: &str) -> Option<chrono::DateTime<chrono::Utc>> {
        // Seconds per unit for each supported relative suffix, tried in order.
        const UNITS: [(char, i64); 3] = [('d', 86_400), ('h', 3_600), ('m', 60)];
        // chrono::Duration::seconds panics beyond ±i64::MAX milliseconds.
        const MAX_SECONDS: i64 = i64::MAX / 1_000;

        for (suffix, unit_seconds) in UNITS {
            let amount = match duration.strip_suffix(suffix).map(str::parse::<i64>) {
                Some(Ok(amount)) => amount,
                // Wrong suffix or non-numeric amount: try the next form.
                _ => continue,
            };
            let seconds = amount.checked_mul(unit_seconds)?;
            if seconds.checked_abs()? > MAX_SECONDS {
                return None;
            }
            return chrono::Utc::now().checked_add_signed(chrono::Duration::seconds(seconds));
        }
        chrono::DateTime::parse_from_rfc3339(duration)
            .ok()
            .map(|dt| dt.with_timezone(&chrono::Utc))
    }

    /// Returns `true` once the current time is strictly after `expires_at`.
    ///
    /// Keys without an expiry never expire.
    pub fn is_expired(&self) -> bool {
        self.expires_at
            .map_or(false, |expires| chrono::Utc::now() > expires)
    }

    /// Extracts the scope encoded in keys shaped like `sk-<scope>-<rest>`.
    ///
    /// Returns `None` when any of the following holds:
    /// - the key lacks the `sk-` prefix;
    /// - it has fewer than three `-`-separated parts;
    /// - the second part is not a name accepted by [`Scope::from_str`].
    pub fn parse_scope_from_key(key: &str) -> Option<Scope> {
        if !key.starts_with("sk-") {
            return None;
        }
        let parts: Vec<&str> = key.split('-').collect();
        if parts.len() < 3 {
            return None;
        }
        Scope::from_str(parts[1])
    }

    /// Returns the `sha256:<hex>` digest of this key's `key` field.
    pub fn hash_key(&self) -> String {
        Self::hash_key_sha256(&self.key)
    }

    /// Hashes `key` with SHA-256 and formats it as `sha256:` followed by hex.
    #[must_use]
    pub fn hash_key_sha256(key: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(key.as_bytes());
        let result = hasher.finalize();
        format!("sha256:{}", hex::encode(result))
    }

    /// Checks `plaintext` against a stored `hash` using constant-time comparison.
    ///
    /// - For `sha256:<hex>` hashes, `plaintext` is hashed and the raw digests are
    ///   compared. Undecodable hex yields `false`.
    /// - Any other `hash` is treated as a legacy plaintext value and compared
    ///   directly, with a warning logged.
    ///
    /// Inputs of differing length return `false` before the constant-time comparison.
    pub fn verify_key(plaintext: &str, hash: &str) -> bool {
        use subtle::ConstantTimeEq;
        if let Some(sha_hash) = hash.strip_prefix("sha256:") {
            let computed = Self::hash_key_sha256(plaintext);
            if let Some(computed_hash) = computed.strip_prefix("sha256:") {
                let expected = hex::decode(sha_hash).unwrap_or_default();
                let actual = hex::decode(computed_hash).unwrap_or_default();
                if expected.len() != actual.len() {
                    return false;
                }
                expected.ct_eq(&actual).into()
            } else {
                false
            }
        } else {
            tracing::warn!("Verifying key against non-prefixed hash (legacy format)");
            let a = plaintext.as_bytes();
            let b = hash.as_bytes();
            if a.len() != b.len() {
                return false;
            }
            a.ct_eq(b).into()
        }
    }

    /// Moves the key into `key_hash` (as `sha256:<hex>`) and replaces `key`
    /// with the placeholder `"[HASHED]"`.
    ///
    /// NOTE(review): collections keyed by `ApiKey::key` (e.g. `AuthConfig::add_key`)
    /// will store every hashed key under the same placeholder — confirm intended use.
    #[must_use]
    pub fn hashed(mut self) -> Self {
        self.key_hash = Some(self.hash_key());
        self.key = "[HASHED]".to_string();
        self
    }

    /// Returns at most the first 8 bytes of the key, used in log fields in
    /// place of the full key.
    ///
    /// The cut is moved back to the nearest UTF-8 character boundary, so keys
    /// containing multi-byte characters no longer cause a slicing panic. For
    /// ASCII keys the result is unchanged.
    #[must_use]
    pub fn key_prefix(&self) -> &str {
        let mut end = std::cmp::min(8, self.key.len());
        while !self.key.is_char_boundary(end) {
            end -= 1;
        }
        &self.key[..end]
    }

    /// Sets a human-readable label.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Returns `true` if `scope` is among this key's scopes.
    ///
    /// Does not consider `enabled` or expiry.
    pub fn has_scope(&self, scope: Scope) -> bool {
        self.scopes.contains(&scope)
    }
}
/// Authentication settings: whether auth is enforced, the known keys, and the
/// paths that bypass authentication.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct AuthConfig {
    /// When `false`, `AuthState::is_public_path` treats every path as public.
    pub enabled: bool,
    /// Configured keys, indexed by each key's `key` string.
    pub api_keys: HashMap<String, ApiKey>,
    /// Paths that skip authentication (see `AuthState::is_public_path` for matching).
    pub public_paths: Vec<String>,
}
impl AuthConfig {
    /// Returns a configuration with authentication off, no keys and no public paths.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Default::default()
        }
    }

    /// Returns an enabled configuration with no keys.
    ///
    /// `/health`, `/ready` and `/metrics` are registered as public paths.
    pub fn enabled() -> Self {
        Self {
            enabled: true,
            public_paths: vec![
                "/health".to_string(),
                "/ready".to_string(),
                "/metrics".to_string(),
            ],
            ..Default::default()
        }
    }

    /// Adds `api_key`, replacing any existing entry with the same `key` string.
    pub fn add_key(mut self, api_key: ApiKey) -> Self {
        self.api_keys.insert(api_key.key.clone(), api_key);
        self
    }

    /// Appends a path that bypasses authentication.
    pub fn add_public_path(mut self, path: impl Into<String>) -> Self {
        self.public_paths.push(path.into());
        self
    }

    /// Builds a configuration from environment variables.
    ///
    /// * `INFERNUM_API_KEYS` — comma-separated `key[:role]` entries.
    ///   - `role` is `admin` or `user`; `user` is the default.
    ///   - Whitespace around entries, keys and roles is trimmed.
    ///   - Blank entries (e.g. from a trailing comma) are skipped.
    ///   - Malformed entries are skipped with a warning. The warning names the
    ///     entry's position, never its contents, since the contents are secrets.
    /// * `INFERNUM_API_KEY` — a single admin key, consulted only when
    ///   `INFERNUM_API_KEYS` yielded no keys. A blank value is ignored.
    ///
    /// Empty keys are never registered: an empty key would authenticate any
    /// request carrying an empty bearer token. If no key is configured,
    /// authentication is disabled and a warning is logged.
    pub fn from_env() -> Self {
        let mut config = Self::enabled();
        if let Ok(keys_str) = std::env::var("INFERNUM_API_KEYS") {
            for (index, entry) in keys_str.split(',').enumerate() {
                let entry = entry.trim();
                if entry.is_empty() {
                    continue;
                }
                let parts: Vec<&str> = entry.split(':').map(str::trim).collect();
                match parts.as_slice() {
                    [key, "admin"] if !key.is_empty() => {
                        config = config.add_key(ApiKey::admin(*key));
                    },
                    [key, "user"] | [key] if !key.is_empty() => {
                        config = config.add_key(ApiKey::user(*key));
                    },
                    _ => {
                        tracing::warn!(
                            "Invalid API key format in INFERNUM_API_KEYS entry #{}",
                            index + 1
                        );
                    },
                }
            }
        }
        if config.api_keys.is_empty() {
            if let Ok(key) = std::env::var("INFERNUM_API_KEY") {
                let key = key.trim();
                if !key.is_empty() {
                    config = config.add_key(ApiKey::admin(key));
                }
            }
        }
        if config.api_keys.is_empty() {
            tracing::warn!("No API keys configured, authentication disabled");
            config.enabled = false;
        }
        config
    }

    /// Returns whether authentication is enforced.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Returns the number of configured keys.
    pub fn key_count(&self) -> usize {
        self.api_keys.len()
    }
}
/// Outcome of looking up a presented key via `AuthState::validate_key_detailed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationResult {
    /// The key exists, is enabled and unexpired; carries its permission level.
    Valid(Permission),
    /// No configured key matches.
    NotFound,
    /// The key exists but is disabled (checked before expiry).
    Disabled,
    /// The key exists and is enabled, but its expiry time has passed.
    Expired,
}
/// Shared authentication state handed to [`auth_middleware`].
///
/// The config and metrics sit behind `Arc`s, so clones observe the same keys
/// and counters. NOTE(review): whether clones share audit-logger state depends
/// on `AuditLogger`'s `Clone` impl — confirm.
#[derive(Debug, Clone)]
pub struct AuthState {
    /// Current configuration, behind an async `RwLock` so keys can be added or
    /// removed at runtime through `&self`.
    config: Arc<RwLock<AuthConfig>>,
    /// Destination for authentication audit events.
    audit_logger: AuditLogger,
    /// Failure counters exposed through `AuthState::metrics`.
    metrics: Arc<AuthMetrics>,
}
impl AuthState {
    /// Creates state for `config` with a default-configured audit logger and
    /// zeroed metrics.
    pub fn new(config: AuthConfig) -> Self {
        Self {
            config: Arc::new(RwLock::new(config)),
            audit_logger: AuditLogger::new(AuditConfig::default()),
            metrics: Arc::new(AuthMetrics::new()),
        }
    }

    /// Creates state for `config` with an explicitly configured audit logger.
    pub fn with_audit_config(config: AuthConfig, audit_config: AuditConfig) -> Self {
        Self {
            config: Arc::new(RwLock::new(config)),
            audit_logger: AuditLogger::new(audit_config),
            metrics: Arc::new(AuthMetrics::new()),
        }
    }

    /// Returns the audit logger used for authentication events.
    pub fn audit_logger(&self) -> &AuditLogger {
        &self.audit_logger
    }

    /// Returns the shared authentication-failure counters.
    pub fn metrics(&self) -> &AuthMetrics {
        &self.metrics
    }

    /// Returns `true` if `path` may be served without authentication.
    ///
    /// While authentication is disabled every path is public. Otherwise `path`
    /// must either equal a configured public path or continue it with a new
    /// path segment. For example, `/health` covers `/health` and
    /// `/health/live` but not `/healthz` or `/health-admin`. A public path that
    /// ends in `/` covers everything beneath it.
    ///
    /// Segment-aware matching prevents a public path from unintentionally
    /// exposing sibling routes that merely share its spelling.
    pub async fn is_public_path(&self, path: &str) -> bool {
        let config = self.config.read().await;
        if !config.enabled {
            return true;
        }
        config
            .public_paths
            .iter()
            .any(|public| Self::path_is_under(path, public))
    }

    /// Segment-aware prefix test: is `path` equal to, or nested under, `prefix`?
    fn path_is_under(path: &str, prefix: &str) -> bool {
        path.strip_prefix(prefix).map_or(false, |rest| {
            rest.is_empty() || rest.starts_with('/') || prefix.ends_with('/')
        })
    }

    /// Returns the key's permission if it validates, or `None` on any failure.
    pub async fn validate_key(&self, key: &str) -> Option<Permission> {
        match self.validate_key_detailed(key).await {
            ValidationResult::Valid(permission) => Some(permission),
            _ => None,
        }
    }

    /// Looks up `key` and classifies the result.
    ///
    /// Checks are applied in order: existence, then `enabled`, then expiry.
    pub async fn validate_key_detailed(&self, key: &str) -> ValidationResult {
        let config = self.config.read().await;
        match config.api_keys.get(key) {
            None => ValidationResult::NotFound,
            Some(api_key) => {
                if !api_key.enabled {
                    tracing::debug!(key_prefix = api_key.key_prefix(), "Key is disabled");
                    return ValidationResult::Disabled;
                }
                if api_key.is_expired() {
                    tracing::debug!(
                        key_prefix = api_key.key_prefix(),
                        expires_at = ?api_key.expires_at,
                        "Key has expired"
                    );
                    return ValidationResult::Expired;
                }
                ValidationResult::Valid(api_key.permission)
            },
        }
    }

    /// Returns whether authentication is currently enforced.
    pub async fn is_enabled(&self) -> bool {
        self.config.read().await.enabled
    }

    /// Adds or replaces a key at runtime, indexed by its `key` string.
    pub async fn add_key(&self, api_key: ApiKey) {
        let mut config = self.config.write().await;
        config.api_keys.insert(api_key.key.clone(), api_key);
    }

    /// Removes a key at runtime; unknown keys are ignored.
    pub async fn remove_key(&self, key: &str) {
        let mut config = self.config.write().await;
        config.api_keys.remove(key);
    }

    /// Returns `true` if `key` exists, is enabled, is unexpired and carries `scope`.
    ///
    /// Expired keys are rejected here just as `validate_key_detailed` rejects
    /// them. A scope check therefore never succeeds for a key that would fail
    /// validation.
    pub async fn has_scope(&self, key: &str, scope: Scope) -> bool {
        let config = self.config.read().await;
        config.api_keys.get(key).map_or(false, |api_key| {
            api_key.enabled && !api_key.is_expired() && api_key.has_scope(scope)
        })
    }
}
impl Default for AuthState {
fn default() -> Self {
Self::new(AuthConfig::disabled())
}
}
/// JSON body returned on authentication/authorization failures.
///
/// Serializes as `{"error": {"message": ..., "type": ..., "code": ...}}`.
#[derive(Debug, Serialize)]
struct AuthError {
    error: AuthErrorDetail,
}
/// The fields nested under `"error"` in an [`AuthError`] body.
#[derive(Debug, Serialize)]
struct AuthErrorDetail {
    /// Human-readable explanation.
    message: String,
    /// Error category; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    error_type: String,
    /// Machine-readable error code.
    code: String,
}
impl AuthError {
fn unauthorized(message: &str) -> Self {
Self {
error: AuthErrorDetail {
message: message.to_string(),
error_type: "authentication_error".to_string(),
code: "invalid_api_key".to_string(),
},
}
}
fn forbidden(message: &str) -> Self {
Self {
error: AuthErrorDetail {
message: message.to_string(),
error_type: "authorization_error".to_string(),
code: "insufficient_permissions".to_string(),
},
}
}
}
/// Extracts the API key presented by the client, if any.
///
/// Sources are checked in order:
/// 1. `Authorization: Bearer <key>`, with the scheme matched case-insensitively
///    as RFC 7235 requires for authentication schemes.
/// 2. The `X-API-Key` header.
///
/// Values are trimmed. An empty value is treated as absent, so an empty token
/// is never looked up as a key and falls through to the next source (or to a
/// "missing key" rejection).
fn extract_api_key(request: &Request) -> Option<String> {
    let headers = request.headers();
    let bearer = headers
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split_once(' '))
        .filter(|(scheme, _)| scheme.eq_ignore_ascii_case("Bearer"))
        .map(|(_, token)| token.trim());
    bearer
        .filter(|token| !token.is_empty())
        .or_else(|| {
            headers
                .get("x-api-key")
                .and_then(|value| value.to_str().ok())
                .map(str::trim)
                .filter(|key| !key.is_empty())
        })
        .map(str::to_string)
}
fn extract_client_ip(request: &Request) -> Option<String> {
if let Some(forwarded) = request.headers().get("x-forwarded-for") {
if let Ok(value) = forwarded.to_str() {
if let Some(ip) = value.split(',').next() {
return Some(ip.trim().to_string());
}
}
}
if let Some(real_ip) = request.headers().get("x-real-ip") {
if let Ok(ip) = real_ip.to_str() {
return Some(ip.trim().to_string());
}
}
None
}
/// Returns a fresh request identifier.
///
/// The identifier is `req-` followed by a random (v4) UUID in its simple,
/// hyphen-free form.
fn generate_request_id() -> String {
    let id = uuid::Uuid::new_v4();
    format!("req-{}", id.as_simple())
}
pub async fn auth_middleware(auth_state: AuthState, request: Request, next: Next) -> Response {
let path = request.uri().path().to_string();
let client_ip = extract_client_ip(&request);
let request_id = generate_request_id();
if auth_state.is_public_path(&path).await {
return next.run(request).await;
}
if !auth_state.is_enabled().await {
return next.run(request).await;
}
let api_key = match extract_api_key(&request) {
Some(key) => key,
None => {
auth_state
.metrics()
.record_failure(AuthFailureReason::MissingKey);
auth_state
.audit_logger()
.auth_failure(&request_id, client_ip.as_deref(), "missing_api_key", &path)
.await;
return (
StatusCode::UNAUTHORIZED,
Json(AuthError::unauthorized(
"Missing API key. Include it in Authorization header as 'Bearer sk-xxx' or X-API-Key header.",
)),
)
.into_response();
},
};
let validation_result = auth_state.validate_key_detailed(&api_key).await;
let permission = match validation_result {
ValidationResult::Valid(perm) => perm,
ValidationResult::NotFound => {
auth_state
.metrics()
.record_failure(AuthFailureReason::InvalidKey);
auth_state
.audit_logger()
.auth_failure(&request_id, client_ip.as_deref(), "invalid_key", &path)
.await;
return (
StatusCode::UNAUTHORIZED,
Json(AuthError::unauthorized("Invalid API key")),
)
.into_response();
},
ValidationResult::Disabled => {
auth_state
.metrics()
.record_failure(AuthFailureReason::DisabledKey);
auth_state
.audit_logger()
.auth_disabled(&request_id, client_ip.as_deref(), &api_key, &path)
.await;
return (
StatusCode::UNAUTHORIZED,
Json(AuthError::unauthorized("API key is disabled")),
)
.into_response();
},
ValidationResult::Expired => {
auth_state
.metrics()
.record_failure(AuthFailureReason::ExpiredKey);
auth_state
.audit_logger()
.auth_expired(&request_id, client_ip.as_deref(), &api_key, &path)
.await;
return (
StatusCode::UNAUTHORIZED,
Json(AuthError::unauthorized("API key has expired")),
)
.into_response();
},
};
if requires_admin_permission(&path) && permission != Permission::Admin {
auth_state
.metrics()
.record_failure(AuthFailureReason::InsufficientScope);
auth_state
.audit_logger()
.scope_violation(&request_id, client_ip.as_deref(), &api_key, &path, "admin")
.await;
return (
StatusCode::FORBIDDEN,
Json(AuthError::forbidden(
"This endpoint requires admin permissions",
)),
)
.into_response();
}
auth_state
.audit_logger()
.auth_success(&request_id, client_ip.as_deref(), &api_key, &path)
.await;
next.run(request).await
}
/// Returns `true` if `path` may only be accessed with [`Permission::Admin`].
///
/// The prefix list mirrors the admin group in `required_scope_for_path`,
/// including the `/admin/models` routes. Without that entry, the
/// permission-based gate in `auth_middleware` let `User` keys reach
/// `/admin/models/*`. Matching is a plain string-prefix test.
fn requires_admin_permission(path: &str) -> bool {
    const ADMIN_PATHS: [&str; 5] = [
        "/api/models/load",
        "/api/models/unload",
        "/api/keys",
        "/api/config",
        "/admin/models",
    ];
    ADMIN_PATHS.iter().any(|prefix| path.starts_with(prefix))
}
// Unit tests for key construction, configuration, validation, path/scope
// mapping, hashing and failure metrics.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ApiKey constructors ----

    #[test]
    fn test_api_key_user() {
        let key = ApiKey::user("sk-test123");
        assert_eq!(key.key, "sk-test123");
        assert_eq!(key.permission, Permission::User);
        assert!(key.enabled);
    }

    #[test]
    fn test_api_key_admin() {
        let key = ApiKey::admin("sk-admin456");
        assert_eq!(key.key, "sk-admin456");
        assert_eq!(key.permission, Permission::Admin);
        assert!(key.enabled);
    }

    #[test]
    fn test_api_key_with_name() {
        let key = ApiKey::user("sk-test").with_name("Production Key");
        assert_eq!(key.name, Some("Production Key".to_string()));
    }

    // ---- AuthConfig builders ----

    #[test]
    fn test_auth_config_disabled() {
        let config = AuthConfig::disabled();
        assert!(!config.enabled);
    }

    #[test]
    fn test_auth_config_enabled() {
        let config = AuthConfig::enabled();
        assert!(config.enabled);
        assert!(config.public_paths.contains(&"/health".to_string()));
        assert!(config.public_paths.contains(&"/ready".to_string()));
    }

    #[test]
    fn test_auth_config_add_key() {
        let config = AuthConfig::enabled()
            .add_key(ApiKey::user("sk-user1"))
            .add_key(ApiKey::admin("sk-admin1"));
        assert_eq!(config.api_keys.len(), 2);
        assert!(config.api_keys.contains_key("sk-user1"));
        assert!(config.api_keys.contains_key("sk-admin1"));
    }

    // ---- AuthState validation and public paths ----

    #[tokio::test]
    async fn test_auth_state_validate_key() {
        let config = AuthConfig::enabled()
            .add_key(ApiKey::user("sk-user"))
            .add_key(ApiKey::admin("sk-admin"));
        let state = AuthState::new(config);
        assert_eq!(state.validate_key("sk-user").await, Some(Permission::User));
        assert_eq!(
            state.validate_key("sk-admin").await,
            Some(Permission::Admin)
        );
        assert_eq!(state.validate_key("sk-invalid").await, None);
    }

    #[tokio::test]
    async fn test_auth_state_public_path() {
        let config = AuthConfig::enabled().add_public_path("/custom/public");
        let state = AuthState::new(config);
        assert!(state.is_public_path("/health").await);
        assert!(state.is_public_path("/ready").await);
        assert!(state.is_public_path("/custom/public").await);
        assert!(!state.is_public_path("/v1/chat/completions").await);
    }

    // With auth disabled, every path (including admin ones) is public.
    #[tokio::test]
    async fn test_auth_state_disabled() {
        let config = AuthConfig::disabled();
        let state = AuthState::new(config);
        assert!(state.is_public_path("/v1/chat/completions").await);
        assert!(state.is_public_path("/api/models/load").await);
    }

    // ---- Permission gating and scopes ----

    #[test]
    fn test_requires_admin_permission() {
        assert!(requires_admin_permission("/api/models/load"));
        assert!(requires_admin_permission("/api/models/unload"));
        assert!(requires_admin_permission("/api/keys"));
        assert!(!requires_admin_permission("/v1/chat/completions"));
        assert!(!requires_admin_permission("/v1/embeddings"));
        assert!(!requires_admin_permission("/health"));
    }

    #[test]
    fn test_permission_default() {
        let perm = Permission::default();
        assert_eq!(perm, Permission::User);
    }

    #[test]
    fn test_scope_display() {
        assert_eq!(Scope::Inference.as_str(), "inference");
        assert_eq!(Scope::Admin.as_str(), "admin");
        assert_eq!(Scope::Metrics.as_str(), "metrics");
    }

    #[test]
    fn test_scope_from_str() {
        assert_eq!(Scope::from_str("inference"), Some(Scope::Inference));
        assert_eq!(Scope::from_str("inf"), Some(Scope::Inference));
        assert_eq!(Scope::from_str("admin"), Some(Scope::Admin));
        assert_eq!(Scope::from_str("adm"), Some(Scope::Admin));
        assert_eq!(Scope::from_str("metrics"), Some(Scope::Metrics));
        assert_eq!(Scope::from_str("met"), Some(Scope::Metrics));
        assert_eq!(Scope::from_str("invalid"), None);
    }

    #[test]
    fn test_api_key_with_scopes() {
        let key = ApiKey::with_scopes("sk-inf-test123", vec![Scope::Inference]);
        assert_eq!(key.key, "sk-inf-test123");
        assert!(key.scopes.contains(&Scope::Inference));
        assert!(!key.scopes.contains(&Scope::Admin));
    }

    #[test]
    fn test_api_key_admin_has_all_scopes() {
        let key = ApiKey::admin("sk-adm-test456");
        assert!(key.scopes.contains(&Scope::Inference));
        assert!(key.scopes.contains(&Scope::Admin));
        assert!(key.scopes.contains(&Scope::Metrics));
    }

    // Keys shaped like `sk-<scope>-<rest>` encode a scope; legacy keys do not.
    #[test]
    fn test_api_key_parse_from_format() {
        let key_str = "sk-inf-abc123def456";
        let scope = ApiKey::parse_scope_from_key(key_str);
        assert_eq!(scope, Some(Scope::Inference));
        let admin_key = "sk-adm-xyz789";
        assert_eq!(ApiKey::parse_scope_from_key(admin_key), Some(Scope::Admin));
        let metrics_key = "sk-met-qrs456";
        assert_eq!(
            ApiKey::parse_scope_from_key(metrics_key),
            Some(Scope::Metrics)
        );
        let legacy_key = "sk-oldkey123";
        assert_eq!(ApiKey::parse_scope_from_key(legacy_key), None);
    }

    #[test]
    fn test_endpoint_scope_requirements() {
        assert_eq!(
            required_scope_for_path("/v1/chat/completions"),
            Some(Scope::Inference)
        );
        assert_eq!(
            required_scope_for_path("/v1/completions"),
            Some(Scope::Inference)
        );
        assert_eq!(
            required_scope_for_path("/v1/embeddings"),
            Some(Scope::Inference)
        );
        assert_eq!(
            required_scope_for_path("/v1/models"),
            Some(Scope::Inference)
        );
        assert_eq!(
            required_scope_for_path("/api/models/load"),
            Some(Scope::Admin)
        );
        assert_eq!(
            required_scope_for_path("/api/models/unload"),
            Some(Scope::Admin)
        );
        assert_eq!(required_scope_for_path("/api/keys"), Some(Scope::Admin));
        assert_eq!(required_scope_for_path("/api/config"), Some(Scope::Admin));
        assert_eq!(
            required_scope_for_path("/admin/models/load"),
            Some(Scope::Admin)
        );
        assert_eq!(
            required_scope_for_path("/admin/models/unload"),
            Some(Scope::Admin)
        );
        assert_eq!(
            required_scope_for_path("/admin/models/status"),
            Some(Scope::Admin)
        );
        assert_eq!(
            required_scope_for_path("/admin/models/warmup"),
            Some(Scope::Admin)
        );
        assert_eq!(required_scope_for_path("/health"), None);
        assert_eq!(required_scope_for_path("/ready"), None);
    }

    #[tokio::test]
    async fn test_auth_state_validate_scopes() {
        let config = AuthConfig::enabled()
            .add_key(ApiKey::with_scopes("sk-inf-user", vec![Scope::Inference]))
            .add_key(ApiKey::admin("sk-adm-admin"));
        let state = AuthState::new(config);
        assert!(state.has_scope("sk-inf-user", Scope::Inference).await);
        assert!(!state.has_scope("sk-inf-user", Scope::Admin).await);
        assert!(state.has_scope("sk-adm-admin", Scope::Inference).await);
        assert!(state.has_scope("sk-adm-admin", Scope::Admin).await);
        assert!(state.has_scope("sk-adm-admin", Scope::Metrics).await);
    }

    // ---- Hashing ----

    #[test]
    fn test_api_key_hashed_storage() {
        let key = ApiKey::with_scopes("sk-inf-secret123", vec![Scope::Inference]);
        let hashed = key.hash_key();
        assert_ne!(hashed, "sk-inf-secret123");
        assert!(ApiKey::verify_key("sk-inf-secret123", &hashed));
        assert!(!ApiKey::verify_key("wrong-key", &hashed));
    }

    // ---- Failure metrics ----

    #[test]
    fn test_auth_metrics_new() {
        let metrics = AuthMetrics::new();
        assert_eq!(metrics.total_failures(), 0);
    }

    #[test]
    fn test_auth_metrics_record_failures() {
        let metrics = AuthMetrics::new();
        metrics.record_failure(AuthFailureReason::MissingKey);
        metrics.record_failure(AuthFailureReason::InvalidKey);
        metrics.record_failure(AuthFailureReason::InvalidKey);
        metrics.record_failure(AuthFailureReason::ExpiredKey);
        metrics.record_failure(AuthFailureReason::DisabledKey);
        metrics.record_failure(AuthFailureReason::InsufficientScope);
        assert_eq!(metrics.total_failures(), 6);
        assert_eq!(metrics.failures_missing_key.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.failures_invalid_key.load(Ordering::Relaxed), 2);
        assert_eq!(metrics.failures_expired_key.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.failures_disabled_key.load(Ordering::Relaxed), 1);
        assert_eq!(
            metrics.failures_insufficient_scope.load(Ordering::Relaxed),
            1
        );
    }

    // Pins the exact Prometheus exposition lines, including zero-valued reasons.
    #[test]
    fn test_auth_metrics_prometheus_format() {
        let metrics = AuthMetrics::new();
        metrics.record_failure(AuthFailureReason::InvalidKey);
        metrics.record_failure(AuthFailureReason::InvalidKey);
        metrics.record_failure(AuthFailureReason::MissingKey);
        let output = metrics.render_prometheus();
        assert!(output.contains("# HELP infernum_auth_failures_total"));
        assert!(output.contains("# TYPE infernum_auth_failures_total counter"));
        assert!(output.contains("infernum_auth_failures_total{reason=\"invalid_key\"} 2"));
        assert!(output.contains("infernum_auth_failures_total{reason=\"missing_key\"} 1"));
        assert!(output.contains("infernum_auth_failures_total{reason=\"expired_key\"} 0"));
    }

    #[test]
    fn test_auth_state_has_metrics() {
        let config = AuthConfig::enabled();
        let state = AuthState::new(config);
        assert_eq!(state.metrics().total_failures(), 0);
        state
            .metrics()
            .record_failure(AuthFailureReason::InvalidKey);
        assert_eq!(state.metrics().total_failures(), 1);
    }
}