use std::sync::Arc;
use axum::{
body::Body,
extract::{Request, State},
http::{StatusCode, header},
middleware::Next,
response::{IntoResponse, Json, Response},
};
use forge_core::auth::Claims;
use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
use forge_core::function::AuthContext;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode, encode};
use tracing::debug;
use super::jwks::JwksClient;
/// Runtime configuration for JWT authentication in the gateway.
///
/// `Debug` is implemented by hand (not derived) so that the HMAC secret is never
/// written to logs or panic messages; it is rendered as `Some("<redacted>")` or `None`.
#[derive(Clone)]
pub struct AuthConfig {
    /// Shared secret for HMAC (`HS*`) algorithms. A missing or empty secret means HMAC
    /// tokens cannot be validated. `auth_middleware` also uses it to sign session cookies.
    pub jwt_secret: Option<String>,
    /// Algorithm tokens must be signed with; tokens using any other algorithm are rejected.
    pub algorithm: JwtAlgorithm,
    /// Key source for RSA (`RS*`) algorithms.
    pub jwks_client: Option<Arc<JwksClient>>,
    /// Expected `iss` claim; when `None` the issuer is not checked.
    pub issuer: Option<String>,
    /// Expected `aud` claim; when `None` audience validation is disabled.
    pub audience: Option<String>,
    /// When `true`, signatures are not verified (development only); expiry is still checked.
    pub(crate) skip_verification: bool,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            // Report only whether a secret is present, never its value.
            .field("jwt_secret", &self.jwt_secret.as_ref().map(|_| "<redacted>"))
            .field("algorithm", &self.algorithm)
            .field("jwks_client", &self.jwks_client)
            .field("issuer", &self.issuer)
            .field("audience", &self.audience)
            .field("skip_verification", &self.skip_verification)
            .finish()
    }
}
impl Default for AuthConfig {
    /// Verification-enabled HS256 configuration with no secret, JWKS client,
    /// issuer, or audience configured.
    fn default() -> Self {
        let algorithm = JwtAlgorithm::HS256;
        Self {
            algorithm,
            skip_verification: false,
            jwt_secret: None,
            jwks_client: None,
            issuer: None,
            audience: None,
        }
    }
}
impl AuthConfig {
pub fn from_forge_config(
config: &forge_core::config::AuthConfig,
) -> Result<Self, super::jwks::JwksError> {
let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
let jwks_client = config
.jwks_url
.as_ref()
.map(|url| JwksClient::new(url.clone(), config.jwks_cache_ttl_secs).map(Arc::new))
.transpose()?;
Ok(Self {
jwt_secret: config.jwt_secret.clone(),
algorithm,
jwks_client,
issuer: config.jwt_issuer.clone(),
audience: config.jwt_audience.clone(),
skip_verification: false,
})
}
pub fn with_secret(secret: impl Into<String>) -> Self {
Self {
jwt_secret: Some(secret.into()),
..Default::default()
}
}
pub fn dev_mode() -> Self {
if std::env::var("FORGE_ENV")
.map(|v| v.eq_ignore_ascii_case("production"))
.unwrap_or(false)
{
tracing::error!(
"AuthConfig::dev_mode() called with FORGE_ENV=production. \
Returning default config with verification enabled."
);
return Self::default();
}
Self {
jwt_secret: None,
algorithm: JwtAlgorithm::HS256,
jwks_client: None,
issuer: None,
audience: None,
skip_verification: true,
}
}
pub fn is_hmac(&self) -> bool {
matches!(
self.algorithm,
JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
)
}
pub fn is_rsa(&self) -> bool {
matches!(
self.algorithm,
JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
)
}
}
/// JWT signing algorithms supported by the auth middleware.
///
/// `HS*` variants are HMAC algorithms validated with `AuthConfig::jwt_secret`.
/// `RS*` variants are RSA algorithms whose keys are fetched from the configured
/// JWKS client. Converts into `jsonwebtoken::Algorithm` and from
/// `forge_core::config::JwtAlgorithm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256; the `Default` variant.
    #[default]
    HS256,
    /// HMAC with SHA-384.
    HS384,
    /// HMAC with SHA-512.
    HS512,
    /// RSASSA-PKCS1-v1_5 with SHA-256.
    RS256,
    /// RSASSA-PKCS1-v1_5 with SHA-384.
    RS384,
    /// RSASSA-PKCS1-v1_5 with SHA-512.
    RS512,
}
impl From<JwtAlgorithm> for Algorithm {
    /// Maps each gateway algorithm onto the identically named `jsonwebtoken` variant.
    fn from(value: JwtAlgorithm) -> Self {
        match value {
            JwtAlgorithm::HS256 => Self::HS256,
            JwtAlgorithm::HS384 => Self::HS384,
            JwtAlgorithm::HS512 => Self::HS512,
            JwtAlgorithm::RS256 => Self::RS256,
            JwtAlgorithm::RS384 => Self::RS384,
            JwtAlgorithm::RS512 => Self::RS512,
        }
    }
}
impl From<CoreJwtAlgorithm> for JwtAlgorithm {
    /// Maps the `forge_core` config algorithm onto the gateway's enum, variant for variant.
    fn from(value: CoreJwtAlgorithm) -> Self {
        match value {
            CoreJwtAlgorithm::HS256 => Self::HS256,
            CoreJwtAlgorithm::HS384 => Self::HS384,
            CoreJwtAlgorithm::HS512 => Self::HS512,
            CoreJwtAlgorithm::RS256 => Self::RS256,
            CoreJwtAlgorithm::RS384 => Self::RS384,
            CoreJwtAlgorithm::RS512 => Self::RS512,
        }
    }
}
/// Signs JWTs with the HMAC secret taken from an [`AuthConfig`].
///
/// Construct it with `HmacTokenIssuer::from_config`. Only `Clone` is derived (no
/// `Debug`), so the secret is not exposed through debug formatting.
#[derive(Clone)]
pub struct HmacTokenIssuer {
    // Raw shared secret; its bytes are the HMAC signing key.
    secret: String,
    // Algorithm written into the JWS header (an HS* variant when built by `from_config`).
    algorithm: Algorithm,
}
impl HmacTokenIssuer {
    /// Creates an issuer when `config` is usable for HMAC signing.
    ///
    /// Returns `None` if the configured algorithm is not an HMAC variant, or if the
    /// secret is missing or empty. Secrets shorter than 32 bytes are still accepted,
    /// but a warning is logged.
    pub fn from_config(config: &AuthConfig) -> Option<Self> {
        const RECOMMENDED_MIN_SECRET_LEN: usize = 32;

        if !config.is_hmac() {
            return None;
        }
        let secret = config
            .jwt_secret
            .as_deref()
            .filter(|candidate| !candidate.is_empty())?
            .to_owned();

        if secret.len() < RECOMMENDED_MIN_SECRET_LEN {
            tracing::warn!(
                secret_len = secret.len(),
                "JWT secret is shorter than 32 bytes. This weakens HMAC security \
                 and may allow brute-force attacks. Use a cryptographically random \
                 secret of at least 32 bytes (e.g. `openssl rand -base64 32`)."
            );
        }

        Some(Self {
            algorithm: config.algorithm.into(),
            secret,
        })
    }
}
impl forge_core::TokenIssuer for HmacTokenIssuer {
    /// Encodes `claims` as a compact JWS signed with this issuer's HMAC secret.
    ///
    /// # Errors
    /// Returns `ForgeError::Internal` when `jsonwebtoken` fails to encode the token.
    fn sign(&self, claims: &Claims) -> forge_core::Result<String> {
        let jws_header = jsonwebtoken::Header::new(self.algorithm);
        let signing_key = jsonwebtoken::EncodingKey::from_secret(self.secret.as_bytes());
        match encode(&jws_header, claims, &signing_key) {
            Ok(token) => Ok(token),
            Err(e) => Err(forge_core::ForgeError::Internal(format!(
                "token signing error: {e}"
            ))),
        }
    }
}
/// Validates bearer JWTs according to an [`AuthConfig`].
///
/// The config is shared behind an `Arc`. `Debug` is implemented manually so that the
/// decoding key appears only as a presence flag.
#[derive(Clone)]
pub struct AuthMiddleware {
    // Immutable auth settings shared by all clones.
    config: Arc<AuthConfig>,
    // Pre-built HMAC decoding key. `None` when verification is skipped, the algorithm is
    // not HMAC, or no non-empty secret is configured (see `AuthMiddleware::new`).
    hmac_key: Option<DecodingKey>,
}
impl std::fmt::Debug for AuthMiddleware {
    /// Formats the middleware with the decoding key reduced to a boolean presence flag.
    ///
    /// `config` is rendered through `AuthConfig`'s own `Debug` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hmac_key_loaded = self.hmac_key.is_some();
        let mut out = f.debug_struct("AuthMiddleware");
        out.field("config", &self.config);
        out.field("hmac_key", &hmac_key_loaded);
        out.finish()
    }
}
impl AuthMiddleware {
pub fn new(config: AuthConfig) -> Self {
if config.skip_verification {
tracing::warn!("JWT signature verification is DISABLED. Do not use in production.");
}
let hmac_key = if config.skip_verification {
None
} else if config.is_hmac() {
config
.jwt_secret
.as_ref()
.filter(|s| !s.is_empty())
.map(|secret| DecodingKey::from_secret(secret.as_bytes()))
} else {
None
};
Self {
config: Arc::new(config),
hmac_key,
}
}
pub fn permissive() -> Self {
Self::new(AuthConfig::dev_mode())
}
pub fn config(&self) -> &AuthConfig {
&self.config
}
pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
if self.config.skip_verification {
return self.decode_without_verification(token);
}
if self.config.is_hmac() {
self.validate_hmac(token)
} else {
self.validate_rsa(token).await
}
}
fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
let key = self.hmac_key.as_ref().ok_or_else(|| {
AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
})?;
self.decode_and_validate(token, key)
}
async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
})?;
let header = jsonwebtoken::decode_header(token)
.map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
debug!(kid = ?header.kid, alg = ?header.alg, "Validating RSA token");
let key = if let Some(kid) = header.kid {
jwks.get_key(&kid).await.map_err(|e| {
AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e))
})?
} else {
jwks.get_any_key()
.await
.map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
};
self.decode_and_validate(token, &key)
}
fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
let mut validation = Validation::new(self.config.algorithm.into());
validation.validate_exp = true;
validation.validate_nbf = true;
validation.leeway = 60;
validation.set_required_spec_claims(&["exp", "sub"]);
if let Some(ref issuer) = self.config.issuer {
validation.set_issuer(&[issuer]);
}
if let Some(ref audience) = self.config.audience {
validation.set_audience(&[audience]);
} else {
validation.validate_aud = false;
}
let token_data =
decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
Ok(token_data.claims)
}
fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
jsonwebtoken::errors::ErrorKind::InvalidSignature => {
AuthError::InvalidToken("Invalid signature".to_string())
}
jsonwebtoken::errors::ErrorKind::InvalidToken => {
AuthError::InvalidToken("Invalid token format".to_string())
}
jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
AuthError::InvalidToken(format!("Missing required claim: {}", claim))
}
jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
AuthError::InvalidToken("Invalid issuer".to_string())
}
jsonwebtoken::errors::ErrorKind::InvalidAudience => {
AuthError::InvalidToken("Invalid audience".to_string())
}
_ => AuthError::InvalidToken(e.to_string()),
}
}
fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
let token_data =
dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
jsonwebtoken::errors::ErrorKind::InvalidToken => {
AuthError::InvalidToken("Invalid token format".to_string())
}
_ => AuthError::InvalidToken(e.to_string()),
})?;
if token_data.claims.is_expired() {
return Err(AuthError::TokenExpired);
}
Ok(token_data.claims)
}
}
/// Errors produced while extracting or validating request credentials.
///
/// `auth_middleware` turns every variant into an HTTP 401 response.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// No `Authorization` header was supplied.
    // NOTE(review): not constructed in this file (`extract_token` returns `Ok(None)` for a
    // missing header); presumably used by callers that require credentials — confirm.
    #[error("Missing authorization header")]
    MissingHeader,
    /// The `Authorization` header is not a usable `Bearer <token>` value.
    #[error("Invalid authorization header format")]
    InvalidHeader,
    /// The token failed decoding or validation; the string describes why.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The token's expiry has passed.
    #[error("Token expired")]
    TokenExpired,
}
/// Collects the client IP and `User-Agent` of `req` for auth-failure diagnostics.
fn extract_auth_diag(req: &Request<Body>) -> (Option<String>, Option<String>) {
    let headers = req.headers();
    (
        crate::gateway::extract_client_ip(headers),
        crate::gateway::extract_header(headers, "user-agent"),
    )
}
/// Emits an `auth.failed` diagnostic signal.
///
/// The payload carries the failure `reason`, a human-readable `detail`, and the request
/// `path`; the client's IP, user agent, and bot classification are attached to the signal.
fn emit_auth_failure(
    reason: &str,
    detail: &str,
    path: &str,
    client_ip: Option<String>,
    user_agent: Option<String>,
) {
    let payload = serde_json::json!({
        "reason": reason,
        "detail": detail,
        "path": path,
    });
    // Classify before `user_agent` is moved into the signal.
    let is_bot = crate::signals::bot::is_bot(user_agent.as_deref());
    crate::signals::emit_diagnostic(
        "auth.failed",
        payload,
        client_ip,
        user_agent,
        None,
        None,
        is_bot,
    );
}
/// Extracts the bearer token from the request's `Authorization` header.
///
/// Returns `Ok(None)` when the header is absent. The auth scheme is matched
/// case-insensitively, as required by RFC 7235 §2.1, so `bearer` and `BEARER` are
/// accepted as well as `Bearer`. Whitespace around the token is trimmed.
///
/// # Errors
/// Returns [`AuthError::InvalidHeader`] in three cases:
/// - the header value cannot be read as a string;
/// - the scheme is not Bearer;
/// - the token part is empty.
pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
    let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
        return Ok(None);
    };
    let header = header_value
        .to_str()
        .map_err(|_| AuthError::InvalidHeader)?;

    // `<scheme> <credentials>`; the scheme token itself must not contain spaces.
    let (scheme, credentials) = header.split_once(' ').ok_or(AuthError::InvalidHeader)?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return Err(AuthError::InvalidHeader);
    }

    let token = credentials.trim();
    if token.is_empty() {
        return Err(AuthError::InvalidHeader);
    }
    Ok(Some(token.to_owned()))
}
pub async fn extract_auth_context_async(
token: Option<String>,
middleware: &AuthMiddleware,
) -> Result<AuthContext, AuthError> {
match token {
Some(token) => middleware
.validate_token_async(&token)
.await
.map(build_auth_context_from_claims),
None => Ok(AuthContext::unauthenticated()),
}
}
/// Builds an authenticated [`AuthContext`] from validated claims.
///
/// The sanitized custom claims plus `sub` are carried over as context claims. When the
/// claims yield a UUID user id, the context is keyed by it; otherwise the subject-only
/// constructor is used.
pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
    let uuid = claims.user_id();
    let mut extra_claims = claims.sanitized_custom();
    extra_claims.insert("sub".to_string(), serde_json::Value::String(claims.sub));
    let roles = claims.roles;

    if let Some(id) = uuid {
        AuthContext::authenticated(id, roles, extra_claims)
    } else {
        AuthContext::authenticated_without_uuid(roles, extra_claims)
    }
}
/// Axum middleware that authenticates the bearer token and stores the resulting
/// [`AuthContext`] in the request extensions.
///
/// A malformed header or an invalid token short-circuits with a 401 JSON body and emits
/// an `auth.failed` diagnostic. A request without a token proceeds unauthenticated.
///
/// For authenticated requests, when a `jwt_secret` is configured and the request does not
/// already carry a `forge_session` cookie, a signed session cookie scoped to
/// `/_api/oauth/` is appended to the response. The cookie is marked `Secure` when
/// `X-Forwarded-Proto` reports HTTPS.
pub async fn auth_middleware(
    State(middleware): State<Arc<AuthMiddleware>>,
    mut req: Request<Body>,
    next: Next,
) -> Response {
    let token = match extract_token(&req) {
        Ok(token) => token,
        Err(e) => {
            let (ip, ua) = extract_auth_diag(&req);
            tracing::warn!(error = %e, "Invalid authorization header");
            emit_auth_failure("invalid_header", &e.to_string(), req.uri().path(), ip, ua);
            return unauthorized_response("Invalid authorization header");
        }
    };
    tracing::trace!(
        token_present = token.is_some(),
        "Auth middleware processing request"
    );

    let auth_context = match extract_auth_context_async(token, &middleware).await {
        Ok(auth_context) => auth_context,
        Err(e) => {
            let (ip, ua) = extract_auth_diag(&req);
            let reason = match &e {
                AuthError::TokenExpired => "token_expired",
                AuthError::InvalidToken(_) => "invalid_token",
                AuthError::MissingHeader => "missing_token",
                AuthError::InvalidHeader => "invalid_header",
            };
            tracing::warn!(error = %e, "Token validation failed");
            emit_auth_failure(reason, &e.to_string(), req.uri().path(), ip, ua);
            return unauthorized_response("Invalid authentication token");
        }
    };
    tracing::trace!(
        authenticated = auth_context.is_authenticated(),
        "Auth context created"
    );

    // Decide on the session cookie before the request (and its headers) is handed on.
    let cookie_subject = if auth_context.is_authenticated()
        && middleware.config.jwt_secret.is_some()
        && !request_has_session_cookie(req.headers())
    {
        auth_context.subject().map(str::to_owned)
    } else {
        None
    };
    let secure = request_is_https(req.headers());

    req.extensions_mut().insert(auth_context);
    let mut response = next.run(req).await;

    if let Some(subject) = cookie_subject
        && let Some(secret) = middleware.config.jwt_secret.as_deref()
    {
        let cookie_value = sign_session_cookie(&subject, secret);
        let secure_flag = if secure { "; Secure" } else { "" };
        let cookie = format!(
            "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Max-Age=86400{secure_flag}"
        );
        if let Ok(val) = axum::http::HeaderValue::from_str(&cookie) {
            response.headers_mut().append(header::SET_COOKIE, val);
        }
    }
    response
}

/// Builds the uniform 401 JSON body returned for every authentication failure.
fn unauthorized_response(message: &str) -> Response {
    (
        StatusCode::UNAUTHORIZED,
        Json(serde_json::json!({
            "success": false,
            "error": { "code": "UNAUTHORIZED", "message": message }
        })),
    )
        .into_response()
}

/// Returns `true` when `X-Forwarded-Proto` reports that the client-facing hop used HTTPS.
///
/// Chained proxies may send a comma-separated list; the first entry describes the
/// browser's connection. The scheme is compared case-insensitively.
fn request_is_https(headers: &axum::http::HeaderMap) -> bool {
    headers
        .get("x-forwarded-proto")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.split(',').next())
        .is_some_and(|proto| proto.trim().eq_ignore_ascii_case("https"))
}

/// Returns `true` if any `Cookie` header already carries a `forge_session` cookie.
///
/// All `Cookie` headers are inspected, because HTTP/2 clients may split cookies across
/// several. Whole cookie names are compared, so e.g. `xforge_session=` does not match.
fn request_has_session_cookie(headers: &axum::http::HeaderMap) -> bool {
    headers
        .get_all(header::COOKIE)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|v| v.split(';'))
        .any(|pair| pair.trim_start().starts_with("forge_session="))
}
/// Produces a signed session cookie value of the form `<subject>.<expiry>.<signature>`.
///
/// `expiry` is the current Unix time plus 86 400 seconds. `signature` is an unpadded
/// base64url HMAC-SHA256 over `<subject>.<expiry>`, keyed by `secret`.
pub fn sign_session_cookie(subject: &str, secret: &str) -> String {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    type HmacSha256 = Hmac<Sha256>;

    let expires_at = chrono::Utc::now().timestamp() + 86400;
    let payload = format!("{subject}.{expires_at}");

    let mut mac =
        HmacSha256::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key length");
    mac.update(payload.as_bytes());
    let signature = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());

    format!("{payload}.{signature}")
}
/// Verifies a session cookie produced by [`sign_session_cookie`] and returns its subject.
///
/// The expected form is `<subject>.<expiry>.<signature>`. The signature is an unpadded
/// base64url HMAC-SHA256 over `<subject>.<expiry>`, keyed by `secret`.
///
/// Returns `None` in any of these cases:
/// - the value is malformed;
/// - the signature does not match;
/// - the expiry (Unix seconds) has passed;
/// - the subject is empty.
pub fn verify_session_cookie(cookie_value: &str, secret: &str) -> Option<String> {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    // The signature is the final segment; base64url output never contains '.'.
    let (payload, sig_encoded) = cookie_value.rsplit_once('.')?;
    let sig_bytes = URL_SAFE_NO_PAD.decode(sig_encoded).ok()?;

    let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).ok()?;
    mac.update(payload.as_bytes());
    // Constant-time comparison; nothing in `payload` is trusted before this succeeds.
    mac.verify_slice(&sig_bytes).ok()?;

    // Subjects may themselves contain '.', so the expiry is the last segment.
    let (subject, expiry_str) = payload.rsplit_once('.')?;
    // A cookie that names no principal must not authenticate anyone.
    if subject.is_empty() {
        return None;
    }
    let expiry: i64 = expiry_str.parse().ok()?;
    if chrono::Utc::now().timestamp() > expiry {
        return None;
    }
    Some(subject.to_owned())
}
#[cfg(test)]
// Tests may unwrap, index, and panic freely: a panic is simply a test failure.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    //! Unit tests for auth configuration, JWT validation, header parsing, and session cookies.
    use super::*;
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    use hmac::{Hmac, Mac};
    use jsonwebtoken::{EncodingKey, Header, encode};
    use sha2::Sha256;

    /// Claims for `test-user-id` with role `user`; `expired` places `exp` an hour in the past.
    fn create_test_claims(expired: bool) -> Claims {
        use forge_core::auth::ClaimsBuilder;
        // A negative lifetime puts `exp` well beyond the 60 s validation leeway.
        let lifetime_secs = if expired { -3600 } else { 3600 };
        ClaimsBuilder::new()
            .subject("test-user-id")
            .role("user")
            .duration_secs(lifetime_secs)
            .build()
            .unwrap()
    }

    /// Signs `claims` with HS256 (the `Header` default) and `secret`.
    fn create_test_token(claims: &Claims, secret: &str) -> String {
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    /// Builds a session cookie with an explicit expiry, mirroring `sign_session_cookie`.
    fn session_cookie_with_expiry(subject: &str, secret: &str, expiry: i64) -> String {
        let payload = format!("{subject}.{expiry}");
        let mut mac =
            Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key");
        mac.update(payload.as_bytes());
        let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
        format!("{payload}.{sig}")
    }

    #[test]
    fn test_auth_config_default() {
        let config = AuthConfig::default();
        assert_eq!(config.algorithm, JwtAlgorithm::HS256);
        assert!(!config.skip_verification);
    }

    #[test]
    fn test_auth_config_dev_mode() {
        let config = AuthConfig::dev_mode();
        assert!(config.skip_verification);
    }

    #[test]
    fn test_auth_middleware_permissive() {
        let middleware = AuthMiddleware::permissive();
        assert!(middleware.config.skip_verification);
    }

    #[tokio::test]
    async fn test_valid_token_with_correct_secret() {
        let secret = "test-secret-key";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, secret);
        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
        let validated_claims = result.unwrap();
        assert_eq!(validated_claims.sub, "test-user-id");
    }

    #[tokio::test]
    async fn test_valid_token_with_wrong_secret() {
        let config = AuthConfig::with_secret("correct-secret");
        let middleware = AuthMiddleware::new(config);
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "wrong-secret");
        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    #[tokio::test]
    async fn test_expired_token() {
        let secret = "test-secret";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);
        let claims = create_test_claims(true);
        let token = create_test_token(&claims, secret);
        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error"),
        }
    }

    #[tokio::test]
    async fn test_tampered_token() {
        let secret = "test-secret";
        let config = AuthConfig::with_secret(secret);
        let middleware = AuthMiddleware::new(config);
        let claims = create_test_claims(false);
        let mut token = create_test_token(&claims, secret);
        if let Some(last_char) = token.pop() {
            let replacement = if last_char == 'a' { 'b' } else { 'a' };
            token.push(replacement);
        }
        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_token_signed_with_unexpected_algorithm_is_rejected() {
        // Same secret, different HMAC algorithm: only the configured HS256 is accepted.
        let secret = "test-secret-key";
        let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));
        let claims = create_test_claims(false);
        let token = encode(
            &Header::new(Algorithm::HS512),
            &claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap();
        let result = middleware.validate_token_async(&token).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[tokio::test]
    async fn test_dev_mode_skips_signature() {
        let config = AuthConfig::dev_mode();
        let middleware = AuthMiddleware::new(config);
        let claims = create_test_claims(false);
        let token = create_test_token(&claims, "any-secret");
        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dev_mode_still_checks_expiration() {
        let config = AuthConfig::dev_mode();
        let middleware = AuthMiddleware::new(config);
        let claims = create_test_claims(true);
        let token = create_test_token(&claims, "any-secret");
        let result = middleware.validate_token_async(&token).await;
        assert!(result.is_err());
        match result {
            Err(AuthError::TokenExpired) => {}
            _ => panic!("Expected TokenExpired error even in dev mode"),
        }
    }

    #[tokio::test]
    async fn test_invalid_token_format() {
        let config = AuthConfig::with_secret("secret");
        let middleware = AuthMiddleware::new(config);
        let result = middleware.validate_token_async("not-a-valid-jwt").await;
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidToken(_)) => {}
            _ => panic!("Expected InvalidToken error"),
        }
    }

    #[test]
    fn test_algorithm_conversion() {
        assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS384), Algorithm::HS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::HS512), Algorithm::HS512);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS384), Algorithm::RS384);
        assert_eq!(Algorithm::from(JwtAlgorithm::RS512), Algorithm::RS512);
    }

    #[test]
    fn test_is_hmac_and_is_rsa() {
        let hmac_config = AuthConfig::with_secret("test");
        assert!(hmac_config.is_hmac());
        assert!(!hmac_config.is_rsa());
        let rsa_config = AuthConfig {
            algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(!rsa_config.is_hmac());
        assert!(rsa_config.is_rsa());
    }

    #[test]
    fn test_extract_token_without_header_is_unauthenticated() {
        let req = Request::builder().body(Body::empty()).unwrap();
        assert!(matches!(extract_token(&req), Ok(None)));
    }

    #[test]
    fn test_extract_token_returns_bearer_token() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Bearer abc.def.ghi")
            .body(Body::empty())
            .unwrap();
        assert_eq!(
            extract_token(&req).unwrap(),
            Some("abc.def.ghi".to_string())
        );
    }

    #[test]
    fn test_extract_token_rejects_non_bearer_header() {
        let req = Request::builder()
            .header(axum::http::header::AUTHORIZATION, "Basic abc")
            .body(Body::empty())
            .unwrap();
        let result = extract_token(&req);
        assert!(matches!(result, Err(AuthError::InvalidHeader)));
    }

    #[test]
    fn test_build_auth_context_from_non_uuid_claims_preserves_subject() {
        let claims = Claims::builder()
            .subject("clerk_user_123")
            .role("member")
            .claim("tenant_id", serde_json::json!("tenant-1"))
            .build()
            .unwrap();
        let auth = build_auth_context_from_claims(claims);
        assert!(auth.is_authenticated());
        assert!(auth.user_id().is_none());
        assert_eq!(auth.subject(), Some("clerk_user_123"));
        assert_eq!(auth.principal_id(), Some("clerk_user_123".to_string()));
        assert!(auth.has_role("member"));
        assert_eq!(
            auth.claim("sub"),
            Some(&serde_json::json!("clerk_user_123"))
        );
    }

    #[test]
    fn test_verify_session_cookie_round_trip_and_tamper_detection() {
        let cookie = sign_session_cookie("user-123", "session-secret");
        assert_eq!(
            verify_session_cookie(&cookie, "session-secret"),
            Some("user-123".to_string())
        );
        let mut tampered = cookie.clone();
        if let Some(last_char) = tampered.pop() {
            tampered.push(if last_char == 'a' { 'b' } else { 'a' });
        }
        assert_eq!(verify_session_cookie(&tampered, "session-secret"), None);
        assert_eq!(verify_session_cookie(&cookie, "wrong-secret"), None);
    }

    #[test]
    fn test_verify_session_cookie_rejects_expired_cookie() {
        let expired_cookie = session_cookie_with_expiry(
            "user-123",
            "session-secret",
            chrono::Utc::now().timestamp() - 1,
        );
        assert_eq!(
            verify_session_cookie(&expired_cookie, "session-secret"),
            None
        );
    }

    #[test]
    fn test_verify_session_cookie_rejects_malformed_values() {
        assert_eq!(verify_session_cookie("", "session-secret"), None);
        assert_eq!(verify_session_cookie("no-separator", "session-secret"), None);
        assert_eq!(
            verify_session_cookie("user.123.not-base64!", "session-secret"),
            None
        );
    }

    #[tokio::test]
    async fn test_extract_auth_context_async_invalid_token_errors() {
        let middleware = AuthMiddleware::new(AuthConfig::with_secret("secret"));
        let result = extract_auth_context_async(Some("bad.token".to_string()), &middleware).await;
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }
}