use std::sync::Arc;
use axum::{
body::Body,
extract::{Request, State},
http::header,
middleware::Next,
response::{IntoResponse, Response},
};
use forge_core::auth::Claims;
use forge_core::config::JwtAlgorithm as CoreJwtAlgorithm;
use forge_core::function::AuthContext;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, dangerous, decode, encode};
use sha2::{Digest, Sha256};
use tracing::debug;
use super::jwks::JwksClient;
/// Derives a short, stable key id for an HMAC secret: the first four bytes of
/// SHA-256(`secret`), rendered as eight lowercase hex digits.
///
/// The id is placed in the JWT `kid` header when signing and compared against it when
/// verifying, so the matching key can be selected without trial decoding.
fn secret_kid(secret: &[u8]) -> String {
    use std::fmt::Write;
    let digest = Sha256::digest(secret);
    let mut kid = String::with_capacity(8);
    digest.iter().take(4).for_each(|byte| {
        // Writing into a String cannot fail.
        let _ = write!(kid, "{byte:02x}");
    });
    kid
}
/// Makes an attacker-controlled string (e.g. a JWT `kid`) safe to embed in a log line.
///
/// Drops every Unicode control character (C0 *and* C1, so both ESC and the single-byte
/// CSI `U+009B` are removed). It also drops the Unicode line/paragraph separators and the
/// bidirectional mark/override/isolate characters, which can fake line breaks or visually
/// reorder log output. The result is truncated to at most 64 characters.
fn sanitize_for_log(s: &str) -> String {
    s.chars()
        .filter(|&c| {
            !c.is_control()
                && !matches!(
                    c,
                    '\u{200E}'
                        | '\u{200F}'
                        | '\u{2028}'
                        | '\u{2029}'
                        | '\u{202A}'..='\u{202E}'
                        | '\u{2066}'..='\u{2069}'
                )
        })
        .take(64)
        .collect()
}
/// Snapshot of the environment variables that `AuthConfig::dev_mode` consults before
/// agreeing to disable signature verification.
///
/// Each field holds the raw variable value, or `None` when the variable is unset or not
/// valid Unicode (`std::env::var(..).ok()`). Keeping this as a struct lets the guard logic
/// be tested without mutating the real process environment.
#[derive(Default)]
struct DevModeEnv {
    /// `FORGE_ENV`
    forge_env: Option<String>,
    /// `NODE_ENV`
    node_env: Option<String>,
    /// `RAILWAY_ENVIRONMENT`
    railway_environment: Option<String>,
    /// `K_SERVICE`
    k_service: Option<String>,
    /// `FLY_APP_NAME`
    fly_app_name: Option<String>,
    /// `KUBERNETES_SERVICE_HOST`
    kubernetes_service_host: Option<String>,
    /// `AWS_EXECUTION_ENV`
    aws_execution_env: Option<String>,
}
/// Whether incoming JWT signatures are verified.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AuthMode {
    /// Signatures (and configured claims) are verified. This is the default.
    #[default]
    Production,
    /// Signatures are NOT verified; tokens are decoded and only checked for expiry.
    /// Only reachable through `AuthConfig::dev_mode`, which refuses in production-like
    /// environments.
    Development,
}
/// Runtime settings for JWT verification, HMAC token issuing and session cookies.
///
/// `Debug` is implemented by hand so that secrets never end up in logs: `jwt_secret` is
/// shown only as present/absent and `legacy_secrets` only by count.
#[derive(Clone)]
pub struct AuthConfig {
    /// HMAC secret used to verify and sign HS256 tokens; also keys the session cookie MAC.
    pub jwt_secret: Option<String>,
    /// Algorithm that incoming tokens must use.
    pub algorithm: JwtAlgorithm,
    /// JWKS source for RS256 verification keys.
    pub jwks_client: Option<Arc<JwksClient>>,
    /// Required `iss` value, when set.
    pub issuer: Option<String>,
    /// Required `aud` value, when set; `aud` is not validated when `None`.
    pub audience: Option<String>,
    /// Clock-skew tolerance, in seconds, applied to time-based claims.
    pub leeway_secs: u64,
    /// Lifetime (and `Max-Age`) of the `forge_session` cookie, in seconds.
    pub session_cookie_ttl_secs: i64,
    /// Previous HMAC secrets still accepted until their `valid_until`.
    pub legacy_secrets: Vec<forge_core::config::LegacySecret>,
    /// Registered claims that must be present in every token.
    pub required_claims: Vec<String>,
    /// Reject RS256 tokens that carry no `kid` header.
    pub jwks_require_kid: bool,
    /// Verification mode; `Development` skips signature checks.
    pub(crate) mode: AuthMode,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            // Never print secret material.
            .field("jwt_secret", &self.jwt_secret.as_ref().map(|_| "<redacted>"))
            .field("algorithm", &self.algorithm)
            .field("jwks_client", &self.jwks_client)
            .field("issuer", &self.issuer)
            .field("audience", &self.audience)
            .field("leeway_secs", &self.leeway_secs)
            .field("session_cookie_ttl_secs", &self.session_cookie_ttl_secs)
            .field("legacy_secrets", &self.legacy_secrets.len())
            .field("required_claims", &self.required_claims)
            .field("jwks_require_kid", &self.jwks_require_kid)
            .field("mode", &self.mode)
            .finish()
    }
}
impl Default for AuthConfig {
    /// Production-mode HS256 configuration with no secret, no JWKS source, no issuer or
    /// audience constraint, 60 s leeway, a one-hour session cookie, no legacy secrets,
    /// `exp` and `sub` as required claims, and `kid` required for JWKS lookups.
    fn default() -> Self {
        let required_claims = ["exp", "sub"].into_iter().map(String::from).collect();
        Self {
            mode: AuthMode::Production,
            algorithm: JwtAlgorithm::HS256,
            jwt_secret: None,
            jwks_client: None,
            jwks_require_kid: true,
            issuer: None,
            audience: None,
            leeway_secs: 60,
            session_cookie_ttl_secs: 3600,
            legacy_secrets: Vec::new(),
            required_claims,
        }
    }
}
impl AuthConfig {
    /// Builds runtime auth settings from the framework-level `forge_core` auth config.
    ///
    /// A [`JwksClient`] is created only when `jwks_url` is configured. The result always
    /// runs in [`AuthMode::Production`], so signatures are verified.
    ///
    /// # Errors
    /// Propagates any [`super::jwks::JwksError`] returned by `JwksClient::new`.
    pub fn from_forge_config(
        config: &forge_core::config::AuthConfig,
    ) -> Result<Self, super::jwks::JwksError> {
        let algorithm = JwtAlgorithm::from(config.jwt_algorithm);
        let jwks_client = config
            .jwks_url
            .as_ref()
            .map(|url| JwksClient::new(url.clone(), config.jwks_cache_ttl.as_secs()).map(Arc::new))
            .transpose()?;
        Ok(Self {
            jwt_secret: config.jwt_secret.clone(),
            algorithm,
            jwks_client,
            issuer: config.jwt_issuer.clone(),
            audience: config.jwt_audience.clone(),
            leeway_secs: config.jwt_leeway.as_secs(),
            session_cookie_ttl_secs: config.session_cookie_ttl_secs(),
            legacy_secrets: config.legacy_secrets.clone(),
            required_claims: config.required_claims.clone(),
            jwks_require_kid: config.jwks_require_kid,
            mode: AuthMode::Production,
        })
    }

    /// Production-mode HS256 config using `secret`; every other field is [`Default`].
    pub fn with_secret(secret: impl Into<String>) -> Self {
        Self {
            jwt_secret: Some(secret.into()),
            ..Default::default()
        }
    }

    /// Development config that skips signature verification, guarded by the process
    /// environment.
    ///
    /// # Errors
    /// Returns a config error when any of these hold:
    /// - `FORGE_ENV` or `NODE_ENV` equals `production` (case-insensitive, surrounding
    ///   whitespace ignored);
    /// - any of `RAILWAY_ENVIRONMENT`, `K_SERVICE`, `FLY_APP_NAME`,
    ///   `KUBERNETES_SERVICE_HOST` or `AWS_EXECUTION_ENV` is set.
    pub fn dev_mode() -> forge_core::Result<Self> {
        Self::dev_mode_with_env(DevModeEnv {
            forge_env: std::env::var("FORGE_ENV").ok(),
            node_env: std::env::var("NODE_ENV").ok(),
            railway_environment: std::env::var("RAILWAY_ENVIRONMENT").ok(),
            k_service: std::env::var("K_SERVICE").ok(),
            fly_app_name: std::env::var("FLY_APP_NAME").ok(),
            kubernetes_service_host: std::env::var("KUBERNETES_SERVICE_HOST").ok(),
            aws_execution_env: std::env::var("AWS_EXECUTION_ENV").ok(),
        })
    }

    /// Testable core of [`AuthConfig::dev_mode`], taking the environment snapshot explicitly.
    fn dev_mode_with_env(env: DevModeEnv) -> forge_core::Result<Self> {
        // Trim so values such as "production\n" (common from env files) cannot slip past
        // the guard.
        let is_production = |value: Option<&str>| {
            value.is_some_and(|v| v.trim().eq_ignore_ascii_case("production"))
        };
        if is_production(env.forge_env.as_deref()) {
            return Err(forge_core::ForgeError::config(
                "AuthConfig::dev_mode() refused: FORGE_ENV=production. \
                 Configure a real jwt_secret or jwks_url instead.",
            ));
        }
        if is_production(env.node_env.as_deref()) {
            return Err(forge_core::ForgeError::config(
                "AuthConfig::dev_mode() refused: NODE_ENV=production detected. \
                 Configure a real jwt_secret or jwks_url instead.",
            ));
        }
        // Any of these being set (even to an empty string) indicates a hosted platform.
        let indicators = [
            ("RAILWAY_ENVIRONMENT", &env.railway_environment),
            ("K_SERVICE", &env.k_service),
            ("FLY_APP_NAME", &env.fly_app_name),
            ("KUBERNETES_SERVICE_HOST", &env.kubernetes_service_host),
            ("AWS_EXECUTION_ENV", &env.aws_execution_env),
        ];
        if let Some((name, _)) = indicators.iter().find(|(_, val)| val.is_some()) {
            return Err(forge_core::ForgeError::config(format!(
                "AuthConfig::dev_mode() refused: {name} is set, indicating a production \
                 environment. Configure a real jwt_secret or jwks_url instead."
            )));
        }
        // Same defaults as production; only the verification mode differs.
        Ok(Self {
            mode: AuthMode::Development,
            ..Self::default()
        })
    }

    /// `true` when tokens are verified with an HMAC (HS256) secret.
    pub fn is_hmac(&self) -> bool {
        matches!(self.algorithm, JwtAlgorithm::HS256)
    }

    /// `true` when tokens are verified with RSA (RS256) keys from JWKS.
    pub fn is_rsa(&self) -> bool {
        matches!(self.algorithm, JwtAlgorithm::RS256)
    }

    /// `true` in development mode, where token signatures are not verified.
    pub fn skips_verification(&self) -> bool {
        matches!(self.mode, AuthMode::Development)
    }
}
/// JWT signing algorithms supported by the runtime verifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JwtAlgorithm {
    /// HMAC-SHA256 with a shared secret (`AuthConfig::jwt_secret`). The default.
    #[default]
    HS256,
    /// RSA-SHA256; verification keys come from `AuthConfig::jwks_client`.
    RS256,
}
/// Maps the runtime algorithm choice onto the `jsonwebtoken` algorithm of the same name.
impl From<JwtAlgorithm> for Algorithm {
    fn from(choice: JwtAlgorithm) -> Self {
        match choice {
            JwtAlgorithm::RS256 => Self::RS256,
            JwtAlgorithm::HS256 => Self::HS256,
        }
    }
}
/// Converts the `forge_core` config algorithm into the runtime enum.
///
/// NOTE(review): a variant this runtime does not know about is logged at `error` level and
/// silently treated as HS256. With a `jwt_secret` configured, that would start accepting
/// HS256 tokens instead of failing fast. Consider a fallible conversion at startup.
impl From<CoreJwtAlgorithm> for JwtAlgorithm {
    fn from(core: CoreJwtAlgorithm) -> Self {
        match core {
            CoreJwtAlgorithm::RS256 => Self::RS256,
            CoreJwtAlgorithm::HS256 => Self::HS256,
            _ => {
                tracing::error!(
                    "Unknown CoreJwtAlgorithm variant; falling back to HS256. \
                     Update forge-runtime to support this algorithm."
                );
                Self::HS256
            }
        }
    }
}
/// Signs JWTs with the configured HMAC secret, stamping each header with the secret's `kid`.
///
/// Deliberately does not derive `Debug`, so the secret cannot be printed.
#[derive(Clone)]
pub struct HmacTokenIssuer {
    /// Raw HMAC secret (non-empty; see `from_config`).
    secret: String,
    /// `secret_kid(secret)`, written into every issued token's `kid` header.
    kid: String,
    /// `jsonwebtoken` algorithm derived from `AuthConfig::algorithm`.
    algorithm: Algorithm,
}
impl HmacTokenIssuer {
pub fn from_config(config: &AuthConfig) -> Option<Self> {
if !config.is_hmac() {
return None;
}
let secret = config.jwt_secret.as_ref()?.clone();
if secret.is_empty() {
return None;
}
if secret.len() < 32 {
tracing::warn!(
secret_len = secret.len(),
"JWT secret is shorter than 32 bytes; startup validation should have caught this"
);
}
let kid = secret_kid(secret.as_bytes());
Some(Self {
secret,
kid,
algorithm: config.algorithm.into(),
})
}
}
impl forge_core::TokenIssuer for HmacTokenIssuer {
    /// Encodes `claims` as a JWT signed with the HMAC secret, with `kid` set in the header.
    ///
    /// # Errors
    /// Any `jsonwebtoken` encoding failure is wrapped as an internal "token signing error".
    fn sign(&self, claims: &Claims) -> forge_core::Result<String> {
        let signing_key = jsonwebtoken::EncodingKey::from_secret(self.secret.as_bytes());
        let mut jwt_header = jsonwebtoken::Header::new(self.algorithm);
        jwt_header.kid = Some(self.kid.clone());
        encode(&jwt_header, claims, &signing_key)
            .map_err(|e| forge_core::ForgeError::internal_with("token signing error", e))
    }
}
/// A previous HMAC secret that keeps verifying tokens until `valid_until`.
///
/// Kept so that tokens minted before a secret rotation stay valid for a grace period.
#[derive(Clone)]
struct LegacyHmacKey {
    /// `secret_kid` of the legacy secret, matched against a token's `kid` header.
    kid: String,
    /// Decoding key built from the legacy secret.
    key: DecodingKey,
    /// End of the grace period; the key is ignored at or after this instant.
    valid_until: chrono::DateTime<chrono::Utc>,
}

/// JWT verifier used by [`auth_middleware`]: pre-built decoding keys plus a short-lived
/// cache of successfully validated tokens.
#[derive(Clone)]
pub struct AuthMiddleware {
    config: Arc<AuthConfig>,
    /// Key for the active HMAC secret (`None` for RSA, dev mode, or a missing/empty secret).
    hmac_key: Option<DecodingKey>,
    /// `secret_kid` of the active HMAC secret, used to route tokens by their `kid`.
    hmac_kid: Option<String>,
    /// Legacy HMAC keys that had not expired when the middleware was built. Expiry is
    /// re-checked on every validation, so keys stop working once `valid_until` passes.
    legacy_hmac_keys: Vec<LegacyHmacKey>,
    /// SHA-256(token) -> (validated claims, cache expiry). Shared across clones.
    token_cache: Arc<dashmap::DashMap<[u8; 32], (Claims, std::time::Instant)>>,
}

impl std::fmt::Debug for AuthMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Key material is reported only as presence/count; the token cache is omitted.
        f.debug_struct("AuthMiddleware")
            .field("config", &self.config)
            .field("hmac_key", &self.hmac_key.is_some())
            .field("hmac_kid", &self.hmac_kid)
            .field("legacy_hmac_keys", &self.legacy_hmac_keys.len())
            .finish_non_exhaustive()
    }
}

impl AuthMiddleware {
    /// Builds a middleware from `config`, precomputing HMAC decoding keys.
    ///
    /// In development mode no keys are built and a warning is logged. Legacy secrets that
    /// are empty are skipped; already-expired ones are skipped with a warning.
    pub fn new(config: AuthConfig) -> Self {
        if config.skips_verification() {
            tracing::warn!("JWT signature verification is DISABLED. Do not use in production.");
        }
        let uses_hmac_keys = config.is_hmac() && !config.skips_verification();
        let active_secret = if uses_hmac_keys {
            config.jwt_secret.as_deref().filter(|s| !s.is_empty())
        } else {
            None
        };
        let hmac_key = active_secret.map(|s| DecodingKey::from_secret(s.as_bytes()));
        let hmac_kid = active_secret.map(|s| secret_kid(s.as_bytes()));
        let legacy_hmac_keys = if uses_hmac_keys {
            let now = chrono::Utc::now();
            config
                .legacy_secrets
                .iter()
                .filter(|ls| {
                    if ls.secret.is_empty() {
                        return false;
                    }
                    if ls.valid_until <= now {
                        tracing::warn!(
                            valid_until = %ls.valid_until,
                            "Legacy JWT secret is expired and will not be used for verification"
                        );
                        return false;
                    }
                    true
                })
                .map(|ls| LegacyHmacKey {
                    kid: secret_kid(ls.secret.as_bytes()),
                    key: DecodingKey::from_secret(ls.secret.as_bytes()),
                    valid_until: ls.valid_until,
                })
                .collect()
        } else {
            Vec::new()
        };
        Self {
            config: Arc::new(config),
            hmac_key,
            hmac_kid,
            legacy_hmac_keys,
            token_cache: Arc::new(dashmap::DashMap::new()),
        }
    }

    /// Middleware that skips signature verification.
    ///
    /// # Errors
    /// Fails when [`AuthConfig::dev_mode`] refuses (production-like environment).
    pub fn permissive() -> forge_core::Result<Self> {
        Ok(Self::new(AuthConfig::dev_mode()?))
    }

    /// The configuration this middleware was built from.
    pub fn config(&self) -> &AuthConfig {
        &self.config
    }

    /// Validates `token` and returns its claims.
    ///
    /// In development mode the token is only decoded and checked for expiry. Otherwise it is
    /// verified with HMAC or JWKS keys. Successful results are cached for at most 60 s (and
    /// never past `exp`). NOTE: a token accepted via a legacy key can therefore stay cached
    /// for up to that TTL after the key's `valid_until`.
    ///
    /// # Errors
    /// [`AuthError::TokenExpired`] or [`AuthError::InvalidToken`].
    pub async fn validate_token_async(&self, token: &str) -> Result<Claims, AuthError> {
        if self.config.skips_verification() {
            return self.decode_without_verification(token);
        }
        let token_hash = Self::hash_token(token);
        if let Some(entry) = self.token_cache.get(&token_hash) {
            let (claims, expires_at) = entry.value();
            if std::time::Instant::now() < *expires_at {
                return Ok(claims.clone());
            }
            // Release the shard read guard before taking the write lock below.
            drop(entry);
            // Remove only if still stale: a concurrent request may have re-inserted it.
            let now = std::time::Instant::now();
            self.token_cache
                .remove_if(&token_hash, |_, (_, expires_at)| *expires_at <= now);
        }
        let claims = if self.config.is_hmac() {
            self.validate_hmac(token)?
        } else {
            self.validate_rsa(token).await?
        };
        let cache_ttl = Self::cache_ttl(&claims);
        if cache_ttl > std::time::Duration::ZERO {
            self.token_cache.insert(
                token_hash,
                (claims.clone(), std::time::Instant::now() + cache_ttl),
            );
        }
        self.evict_expired_cache_entries();
        Ok(claims)
    }

    /// Cache key for a raw token (so the token itself is not stored as a key).
    fn hash_token(token: &str) -> [u8; 32] {
        Sha256::digest(token.as_bytes()).into()
    }

    /// How long validated claims may be cached: the remaining lifetime until `exp`,
    /// capped at 60 s; zero when `exp` is not in the future.
    fn cache_ttl(claims: &Claims) -> std::time::Duration {
        const MAX_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(60);
        let remaining_secs = claims.exp().saturating_sub(chrono::Utc::now().timestamp());
        // Non-positive remaining lifetime fails the conversion and maps to zero (not cached).
        std::time::Duration::from_secs(u64::try_from(remaining_secs).unwrap_or(0))
            .min(MAX_CACHE_TTL)
    }

    /// Keeps the token cache bounded.
    ///
    /// Once the size limit is exceeded, expired entries are purged. If every entry is still
    /// live, the cache is cleared rather than left to grow and be rescanned on every
    /// request. This is only a cache, so dropping entries merely forces re-validation.
    fn evict_expired_cache_entries(&self) {
        const MAX_CACHE_SIZE: usize = 10_000;
        if self.token_cache.len() <= MAX_CACHE_SIZE {
            return;
        }
        let now = std::time::Instant::now();
        self.token_cache
            .retain(|_, (_, expires_at)| *expires_at > now);
        if self.token_cache.len() > MAX_CACHE_SIZE {
            self.token_cache.clear();
        }
    }

    /// Legacy HMAC keys whose grace period has not ended as of `now`.
    fn live_legacy_keys(
        &self,
        now: chrono::DateTime<chrono::Utc>,
    ) -> impl Iterator<Item = &LegacyHmacKey> + '_ {
        self.legacy_hmac_keys
            .iter()
            .filter(move |k| k.valid_until > now)
    }

    /// Verifies an HS256 token against the active secret and any live legacy secrets.
    ///
    /// A token whose `kid` names a known key is checked against that key only. Otherwise the
    /// active key is tried first, then each live legacy key. If no key accepts the token,
    /// the active key's error is returned. The exception is `TokenExpired` from a legacy key,
    /// which is reported as such.
    fn validate_hmac(&self, token: &str) -> Result<Claims, AuthError> {
        let primary = self.hmac_key.as_ref().ok_or_else(|| {
            AuthError::InvalidToken("JWT secret not configured for HMAC".to_string())
        })?;
        let now = chrono::Utc::now();
        let token_kid = jsonwebtoken::decode_header(token).ok().and_then(|h| h.kid);
        if let Some(tkid) = token_kid.as_deref() {
            if self.hmac_kid.as_deref() == Some(tkid) {
                return self.decode_and_validate(token, primary);
            }
            if let Some(legacy) = self.live_legacy_keys(now).find(|k| k.kid == tkid) {
                return self.decode_and_validate(token, &legacy.key);
            }
            debug!(kid = %sanitize_for_log(tkid), "Token kid not recognised; falling back to full key scan");
        }
        match self.decode_and_validate(token, primary) {
            Err(AuthError::InvalidToken(primary_reason)) => {
                for legacy in self.live_legacy_keys(now) {
                    match self.decode_and_validate(token, &legacy.key) {
                        Ok(claims) => return Ok(claims),
                        // NOTE: relies on jsonwebtoken checking the signature before the
                        // claims, so an expiry error means this key did sign the token.
                        Err(AuthError::TokenExpired) => return Err(AuthError::TokenExpired),
                        Err(_) => {}
                    }
                }
                Err(AuthError::InvalidToken(primary_reason))
            }
            other => other,
        }
    }

    /// Verifies an RS256 token with a key fetched from the configured JWKS.
    ///
    /// Tokens without a `kid` are rejected unless `jwks_require_kid` is false. In that case
    /// the JWKS client's `get_any_key` is used.
    async fn validate_rsa(&self, token: &str) -> Result<Claims, AuthError> {
        let jwks = self.config.jwks_client.as_ref().ok_or_else(|| {
            AuthError::InvalidToken("JWKS URL not configured for RSA".to_string())
        })?;
        let header = jsonwebtoken::decode_header(token)
            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
        let safe_kid = header.kid.as_deref().map(sanitize_for_log);
        debug!(kid = ?safe_kid, alg = ?header.alg, "Validating RSA token");
        let key = if let Some(kid) = header.kid {
            jwks.get_key(&kid).await.map_err(|e| {
                // The kid is attacker-controlled and this message is logged, so sanitize it.
                // NOTE(review): `e` may echo the kid too; confirm JwksError's Display.
                AuthError::InvalidToken(format!(
                    "Failed to get key '{}': {}",
                    sanitize_for_log(&kid),
                    e
                ))
            })?
        } else if self.config.jwks_require_kid {
            return Err(AuthError::InvalidToken(
                "RS256 token missing kid header; set auth.jwks_require_kid = false to allow kidless tokens".to_string(),
            ));
        } else {
            jwks.get_any_key()
                .await
                .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?
        };
        self.decode_and_validate(token, &key)
    }

    /// Decodes `token` with `key` and enforces algorithm, `exp`/`nbf` (with leeway),
    /// required claims, and the configured issuer/audience.
    fn decode_and_validate(&self, token: &str, key: &DecodingKey) -> Result<Claims, AuthError> {
        let header = jsonwebtoken::decode_header(token)
            .map_err(|e| AuthError::InvalidToken(format!("Invalid token header: {}", e)))?;
        let expected = Algorithm::from(self.config.algorithm);
        // Pin the algorithm explicitly so a token cannot choose its own verification scheme.
        if header.alg != expected {
            return Err(AuthError::InvalidToken(format!(
                "Token algorithm {:?} does not match configured {:?}",
                header.alg, expected
            )));
        }
        let mut validation = Validation::new(expected);
        validation.validate_exp = true;
        validation.validate_nbf = true;
        validation.leeway = self.config.leeway_secs;
        let required: Vec<&str> = self
            .config
            .required_claims
            .iter()
            .map(String::as_str)
            .collect();
        validation.set_required_spec_claims(&required);
        if let Some(ref issuer) = self.config.issuer {
            validation.set_issuer(&[issuer]);
        }
        if let Some(ref audience) = self.config.audience {
            validation.set_audience(&[audience]);
        } else {
            validation.validate_aud = false;
        }
        let token_data =
            decode::<Claims>(token, key, &validation).map_err(|e| self.map_jwt_error(e))?;
        Ok(token_data.claims)
    }

    /// Translates `jsonwebtoken` errors into [`AuthError`] with short, stable messages.
    fn map_jwt_error(&self, e: jsonwebtoken::errors::Error) -> AuthError {
        match e.kind() {
            jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::TokenExpired,
            jsonwebtoken::errors::ErrorKind::InvalidSignature => {
                AuthError::InvalidToken("Invalid signature".to_string())
            }
            jsonwebtoken::errors::ErrorKind::InvalidToken => {
                AuthError::InvalidToken("Invalid token format".to_string())
            }
            jsonwebtoken::errors::ErrorKind::MissingRequiredClaim(claim) => {
                AuthError::InvalidToken(format!("Missing required claim: {}", claim))
            }
            jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
                AuthError::InvalidToken("Invalid issuer".to_string())
            }
            jsonwebtoken::errors::ErrorKind::InvalidAudience => {
                AuthError::InvalidToken("Invalid audience".to_string())
            }
            _ => AuthError::InvalidToken(e.to_string()),
        }
    }

    /// Development-mode decoding: no signature check, but expired tokens are still rejected.
    fn decode_without_verification(&self, token: &str) -> Result<Claims, AuthError> {
        let token_data =
            dangerous::insecure_decode::<Claims>(token).map_err(|e| match e.kind() {
                jsonwebtoken::errors::ErrorKind::InvalidToken => {
                    AuthError::InvalidToken("Invalid token format".to_string())
                }
                _ => AuthError::InvalidToken(e.to_string()),
            })?;
        if token_data.claims.is_expired() {
            return Err(AuthError::TokenExpired);
        }
        Ok(token_data.claims)
    }
}
/// Reasons a request's credentials were rejected.
///
/// The `Display` text may be logged and sent to diagnostics (see `auth_middleware`). Clients
/// only ever see a generic "unauthorized" response.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AuthError {
    /// No `Authorization` header. (`extract_token` itself returns `Ok(None)` in that case.)
    #[error("Missing authorization header")]
    MissingHeader,
    /// The `Authorization` header is not valid text or not a non-empty `Bearer` credential.
    #[error("Invalid authorization header format")]
    InvalidHeader,
    /// The token failed decoding or verification; the payload describes why.
    #[error("Invalid token: {0}")]
    InvalidToken(String),
    /// The token's `exp` has passed.
    #[error("Token expired")]
    TokenExpired,
}
/// Collects the client IP (as resolved by the gateway) and the `User-Agent` header for
/// auth diagnostics and session-cookie binding.
fn extract_auth_diag(req: &Request<Body>) -> (Option<String>, Option<String>) {
    let client_ip = req
        .extensions()
        .get::<crate::gateway::ResolvedClientIp>()
        .and_then(|resolved| resolved.0.clone());
    (
        client_ip,
        crate::gateway::extract_header(req.headers(), "user-agent"),
    )
}
/// Emits an `auth.failed` diagnostic signal carrying the failure reason, detail and path,
/// along with client identity and a bot classification derived from the user agent.
fn emit_auth_failure(
    reason: &str,
    detail: &str,
    path: &str,
    client_ip: Option<String>,
    user_agent: Option<String>,
) {
    let looks_like_bot = crate::signals::bot::is_bot(user_agent.as_deref());
    let payload = serde_json::json!({
        "reason": reason,
        "detail": detail,
        "path": path,
    });
    crate::signals::emit_diagnostic(
        "auth.failed",
        payload,
        client_ip,
        user_agent,
        None,
        None,
        looks_like_bot,
    );
}
/// Extracts the bearer token from the `Authorization` header.
///
/// Returns `Ok(None)` when the header is absent. The scheme name is matched
/// case-insensitively (RFC 7235 §2.1), so `Bearer`, `bearer` and `BEARER` are all accepted.
/// Surrounding whitespace around the token is trimmed.
///
/// # Errors
/// [`AuthError::InvalidHeader`] when the header is not valid visible ASCII, uses another
/// scheme, or carries an empty token.
pub fn extract_token(req: &Request<Body>) -> Result<Option<String>, AuthError> {
    let Some(header_value) = req.headers().get(axum::http::header::AUTHORIZATION) else {
        return Ok(None);
    };
    let header = header_value
        .to_str()
        .map_err(|_| AuthError::InvalidHeader)?;
    let (scheme, credentials) = header.split_once(' ').ok_or(AuthError::InvalidHeader)?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return Err(AuthError::InvalidHeader);
    }
    let token = credentials.trim();
    if token.is_empty() {
        return Err(AuthError::InvalidHeader);
    }
    Ok(Some(token.to_string()))
}
pub async fn extract_auth_context_async(
token: Option<String>,
middleware: &AuthMiddleware,
) -> Result<AuthContext, AuthError> {
match token {
Some(token) => middleware
.validate_token_async(&token)
.await
.map(build_auth_context_from_claims),
None => Ok(AuthContext::unauthenticated()),
}
}
/// Builds an authenticated [`AuthContext`] from validated claims.
///
/// The sanitized custom claims are kept and the subject is added under `"sub"`. A UUID user
/// id is attached when the claims provide one; otherwise the subject alone identifies the
/// principal. The token's `exp` is recorded on the context.
pub fn build_auth_context_from_claims(claims: Claims) -> AuthContext {
    let token_exp = claims.exp();
    let maybe_uuid = claims.user_id();
    let mut extra_claims = claims.sanitized_custom();
    let subject = claims.sub().to_string();
    // Consumes `claims`; everything borrowed from it has been extracted above.
    let roles = claims.into_roles();
    extra_claims.insert("sub".to_string(), serde_json::Value::String(subject));
    let context = if let Some(uuid) = maybe_uuid {
        AuthContext::authenticated(uuid, roles, extra_claims)
    } else {
        AuthContext::authenticated_without_uuid(roles, extra_claims)
    };
    context.with_token_exp(token_exp)
}
/// Returns `true` when any `Cookie` header on the request carries a `forge_session` cookie.
///
/// Cookie pairs are compared by exact name, so `my_forge_session=...` does not count. Every
/// `Cookie` header is inspected, because HTTP/2 clients may split cookies across several.
fn request_has_session_cookie(headers: &axum::http::HeaderMap) -> bool {
    headers
        .get_all(header::COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|cookies| cookies.split(';'))
        .any(|pair| pair.trim_start().starts_with("forge_session="))
}

/// Axum middleware that authenticates the request and attaches an [`AuthContext`] extension.
///
/// A malformed `Authorization` header or an invalid token short-circuits with an RPC
/// "unauthorized" response and emits an `auth.failed` diagnostic. A missing header yields an
/// unauthenticated context.
///
/// For authenticated requests without a `forge_session` cookie, when a non-empty
/// `jwt_secret` is configured, a signed session cookie bound to the client's coarsened IP and
/// hashed user agent is appended to the response.
pub async fn auth_middleware(
    State(middleware): State<Arc<AuthMiddleware>>,
    mut req: Request<Body>,
    next: Next,
) -> Response {
    let token = match extract_token(&req) {
        Ok(token) => token,
        Err(e) => {
            let (ip, ua) = extract_auth_diag(&req);
            tracing::warn!(error = %e, "Invalid authorization header");
            emit_auth_failure("invalid_header", &e.to_string(), req.uri().path(), ip, ua);
            return super::response::RpcResponse::error(super::response::RpcError::unauthorized(
                "Invalid authorization header",
            ))
            .into_response();
        }
    };
    tracing::trace!(
        token_present = token.is_some(),
        "Auth middleware processing request"
    );
    let auth_context = match extract_auth_context_async(token, &middleware).await {
        Ok(auth_context) => auth_context,
        Err(e) => {
            let (ip, ua) = extract_auth_diag(&req);
            let reason = match &e {
                AuthError::TokenExpired => "token_expired",
                AuthError::InvalidToken(_) => "invalid_token",
                AuthError::MissingHeader => "missing_token",
                AuthError::InvalidHeader => "invalid_header",
            };
            tracing::warn!(error = %e, "Token validation failed");
            emit_auth_failure(reason, &e.to_string(), req.uri().path(), ip, ua);
            return super::response::RpcResponse::error(super::response::RpcError::unauthorized(
                "Invalid authentication token",
            ))
            .into_response();
        }
    };
    tracing::trace!(
        authenticated = auth_context.is_authenticated(),
        "Auth context created"
    );

    // Never key the cookie MAC with an empty secret: HMAC accepts it, but anyone could then
    // forge a valid signature.
    let cookie_secret = middleware
        .config
        .jwt_secret
        .as_deref()
        .filter(|s| !s.is_empty());
    let should_set_cookie = auth_context.is_authenticated()
        && cookie_secret.is_some()
        && !request_has_session_cookie(req.headers());
    // Captured before `req` moves into `next`; the cookie is bound to these values.
    let (cookie_ip, cookie_ua) = extract_auth_diag(&req);

    req.extensions_mut().insert(auth_context.clone());
    let mut response = next.run(req).await;

    if should_set_cookie
        && let Some(subject) = auth_context.subject()
        && let Some(secret) = cookie_secret
    {
        let cookie_ttl = middleware.config.session_cookie_ttl_secs;
        let cookie_value = sign_session_cookie(
            subject,
            secret,
            cookie_ttl,
            cookie_ip.as_deref(),
            cookie_ua.as_deref(),
        );
        // An empty value means signing failed; don't emit a useless cookie.
        if !cookie_value.is_empty() {
            let cookie = format!(
                "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Secure; Max-Age={cookie_ttl}"
            );
            if let Ok(val) = axum::http::HeaderValue::from_str(&cookie) {
                response.headers_mut().append(header::SET_COOKIE, val);
            }
        }
    }
    response
}
/// Reduces a client IP to a coarse network prefix for session-cookie binding.
///
/// - IPv4 keeps the first three octets (`/24`), e.g. `"192.168.1"`.
/// - IPv6 keeps the first three hextets (`/48`), e.g. `"2001:db8:85a3"`.
/// - IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are treated as IPv4 first. Otherwise every
///   such client would collapse to `"0:0:0"` and the binding would be meaningless.
/// - Unparseable input yields an empty string.
fn coarsen_ip(ip: &str) -> String {
    let Ok(addr) = ip.parse::<std::net::IpAddr>() else {
        return String::new();
    };
    match addr.to_canonical() {
        std::net::IpAddr::V4(v4) => {
            let [a, b, c, _] = v4.octets();
            format!("{a}.{b}.{c}")
        }
        std::net::IpAddr::V6(v6) => {
            let [s0, s1, s2, ..] = v6.segments();
            format!("{s0:x}:{s1:x}:{s2:x}")
        }
    }
}
/// Fingerprints a `User-Agent` string for session-cookie binding: the first eight bytes of
/// SHA-256(`ua`), hex-encoded as 16 lowercase characters.
fn hash_ua(ua: &str) -> String {
    use std::fmt::Write;
    Sha256::digest(ua.as_bytes())
        .iter()
        .take(8)
        .fold(String::with_capacity(16), |mut hex, byte| {
            // Writing into a String cannot fail.
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
/// Produces a signed `forge_session` cookie value for `subject`.
///
/// Format: `<base64url(subject)>:<unix-expiry>.<base64url(HMAC-SHA256)>`. The MAC covers the
/// payload plus the coarsened client IP and hashed user agent, so the cookie only verifies
/// from a similar network and browser. Those binding values are not stored in the cookie.
///
/// The expiry is `now + ttl_secs`, saturating at `i64::MAX` instead of overflowing. Returns an
/// empty string if the HMAC cannot be keyed.
pub fn sign_session_cookie(
    subject: &str,
    secret: &str,
    ttl_secs: i64,
    client_ip: Option<&str>,
    user_agent: Option<&str>,
) -> String {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    use hmac::{Hmac, Mac};
    use sha2::Sha256;
    // Saturating: a huge configured TTL must not panic (debug) or wrap to a past expiry
    // (release).
    let expiry = chrono::Utc::now().timestamp().saturating_add(ttl_secs);
    let ip_prefix = client_ip.map(coarsen_ip).unwrap_or_default();
    let ua_hash = user_agent.map(hash_ua).unwrap_or_default();
    // Base64url keeps arbitrary subjects (including '.' or ':') from breaking the layout.
    let encoded_subject = URL_SAFE_NO_PAD.encode(subject.as_bytes());
    let payload = format!("{encoded_subject}:{expiry}");
    let binding = format!("{payload}.{ip_prefix}.{ua_hash}");
    let Ok(mut mac) = Hmac::<Sha256>::new_from_slice(secret.as_bytes()) else {
        return String::new();
    };
    mac.update(binding.as_bytes());
    let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
    format!("{payload}.{sig}")
}
#[cfg(feature = "mcp-oauth")]
/// Checks a `forge_session` cookie produced by [`sign_session_cookie`] and returns its subject.
///
/// The MAC is recomputed over the payload plus the caller's coarsened IP and hashed user
/// agent. It must match, and the embedded expiry must not have passed. Any malformed part,
/// MAC mismatch or expiry yields `None`.
pub fn verify_session_cookie(
    cookie_value: &str,
    secret: &str,
    client_ip: Option<&str>,
    user_agent: Option<&str>,
) -> Option<String> {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    // Layout: `<b64url(subject)>:<expiry>.<b64url(mac)>`. The last '.' separates the
    // payload from the signature.
    let (payload, encoded_sig) = cookie_value.rsplit_once('.')?;
    let signature = URL_SAFE_NO_PAD.decode(encoded_sig).ok()?;

    let ip_prefix = client_ip.map_or_else(String::new, coarsen_ip);
    let ua_hash = user_agent.map_or_else(String::new, hash_ua);
    let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).ok()?;
    mac.update(format!("{payload}.{ip_prefix}.{ua_hash}").as_bytes());
    // `verify_slice` compares in constant time.
    mac.verify_slice(&signature).ok()?;

    let (encoded_subject, expiry_text) = payload.rsplit_once(':')?;
    let expiry: i64 = expiry_text.parse().ok()?;
    if expiry < chrono::Utc::now().timestamp() {
        return None;
    }
    let subject_bytes = URL_SAFE_NO_PAD.decode(encoded_subject).ok()?;
    String::from_utf8(subject_bytes).ok()
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
use super::*;
#[cfg(feature = "mcp-oauth")]
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
#[cfg(feature = "mcp-oauth")]
use hmac::{Hmac, Mac};
use jsonwebtoken::{EncodingKey, Header, encode};
#[cfg(feature = "mcp-oauth")]
use sha2::Sha256;
fn create_test_claims(expired: bool) -> Claims {
use forge_core::auth::ClaimsBuilder;
let mut builder = ClaimsBuilder::new().subject("test-user-id").role("user");
if expired {
builder = builder.duration_secs(-3600); } else {
builder = builder.duration_secs(3600); }
builder.build().unwrap()
}
fn create_test_token(claims: &Claims, secret: &str) -> String {
encode(
&Header::default(),
claims,
&EncodingKey::from_secret(secret.as_bytes()),
)
.unwrap()
}
#[cfg(feature = "mcp-oauth")]
fn session_cookie_with_expiry(subject: &str, secret: &str, expiry: i64) -> String {
let encoded_subject = URL_SAFE_NO_PAD.encode(subject.as_bytes());
let payload = format!("{encoded_subject}:{expiry}");
let ip_prefix = coarsen_ip("192.168.1.42");
let ua_hash = hash_ua("TestAgent/1.0");
let binding = format!("{payload}.{ip_prefix}.{ua_hash}");
let mut mac =
Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key");
mac.update(binding.as_bytes());
let sig = URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes());
format!("{payload}.{sig}")
}
#[cfg(feature = "mcp-oauth")]
#[test]
fn test_coarsen_ip_masks_correctly() {
assert_eq!(coarsen_ip("192.168.1.42"), "192.168.1");
assert_eq!(coarsen_ip("10.0.0.1"), "10.0.0");
assert_eq!(coarsen_ip("2001:db8:85a3::8a2e:370:7334"), "2001:db8:85a3");
assert_eq!(coarsen_ip("not-an-ip"), "");
}
#[cfg(feature = "mcp-oauth")]
#[test]
fn test_hash_ua_deterministic() {
let h1 = hash_ua("Mozilla/5.0");
let h2 = hash_ua("Mozilla/5.0");
assert_eq!(h1, h2);
assert_ne!(hash_ua("Mozilla/5.0"), hash_ua("Chrome/100"));
}
#[test]
fn sanitize_for_log_strips_control_chars_and_truncates() {
assert_eq!(sanitize_for_log("normal-kid"), "normal-kid");
assert_eq!(sanitize_for_log("\x1b[2K\rok"), "[2Kok");
assert_eq!(sanitize_for_log("a\n\r\tb"), "ab");
let long = "x".repeat(100);
assert_eq!(sanitize_for_log(&long).len(), 64);
}
#[test]
fn test_auth_config_default() {
let config = AuthConfig::default();
assert_eq!(config.algorithm, JwtAlgorithm::HS256);
assert_eq!(config.mode, AuthMode::Production);
assert!(!config.skips_verification());
}
#[test]
fn test_auth_config_dev_mode() {
let config = AuthConfig::dev_mode().expect("dev_mode outside production");
assert_eq!(config.mode, AuthMode::Development);
assert!(config.skips_verification());
}
#[test]
fn test_auth_middleware_permissive() {
let middleware = AuthMiddleware::permissive().expect("permissive outside production");
assert!(middleware.config.skips_verification());
}
#[test]
fn test_dev_mode_refuses_in_production() {
let result = AuthConfig::dev_mode_with_env(DevModeEnv {
forge_env: Some("production".into()),
..DevModeEnv::default()
});
assert!(matches!(result, Err(forge_core::ForgeError::Config { .. })));
}
#[test]
fn test_dev_mode_refuses_case_insensitive() {
for v in ["Production", "PRODUCTION", "production"] {
let result = AuthConfig::dev_mode_with_env(DevModeEnv {
forge_env: Some(v.into()),
..DevModeEnv::default()
});
assert!(matches!(result, Err(forge_core::ForgeError::Config { .. })));
}
}
#[test]
fn test_dev_mode_refuses_node_env_production() {
let result = AuthConfig::dev_mode_with_env(DevModeEnv {
node_env: Some("production".into()),
..DevModeEnv::default()
});
assert!(matches!(result, Err(forge_core::ForgeError::Config { .. })));
}
#[test]
fn test_dev_mode_refuses_cloud_platform_indicators() {
for (field, val) in [
("RAILWAY_ENVIRONMENT", "production"),
("K_SERVICE", "my-svc"),
("FLY_APP_NAME", "my-app"),
("KUBERNETES_SERVICE_HOST", "10.0.0.1"),
("AWS_EXECUTION_ENV", "AWS_ECS_FARGATE"),
] {
let mut env = DevModeEnv::default();
match field {
"RAILWAY_ENVIRONMENT" => env.railway_environment = Some(val.into()),
"K_SERVICE" => env.k_service = Some(val.into()),
"FLY_APP_NAME" => env.fly_app_name = Some(val.into()),
"KUBERNETES_SERVICE_HOST" => env.kubernetes_service_host = Some(val.into()),
"AWS_EXECUTION_ENV" => env.aws_execution_env = Some(val.into()),
_ => {}
}
let result = AuthConfig::dev_mode_with_env(env);
assert!(
matches!(result, Err(forge_core::ForgeError::Config { .. })),
"{field} should block dev mode"
);
}
}
#[test]
fn test_dev_mode_allows_other_env_values() {
for forge_env in [None, Some("development"), Some("staging"), Some("")] {
let result = AuthConfig::dev_mode_with_env(DevModeEnv {
forge_env: forge_env.map(String::from),
..DevModeEnv::default()
});
assert!(result.is_ok());
}
}
#[tokio::test]
async fn test_valid_token_with_correct_secret() {
let secret = "test-secret-key";
let config = AuthConfig::with_secret(secret);
let middleware = AuthMiddleware::new(config);
let claims = create_test_claims(false);
let token = create_test_token(&claims, secret);
let result = middleware.validate_token_async(&token).await;
assert!(result.is_ok());
let validated_claims = result.unwrap();
assert_eq!(validated_claims.sub(), "test-user-id");
}
#[tokio::test]
async fn test_valid_token_with_wrong_secret() {
let config = AuthConfig::with_secret("correct-secret");
let middleware = AuthMiddleware::new(config);
let claims = create_test_claims(false);
let token = create_test_token(&claims, "wrong-secret");
let result = middleware.validate_token_async(&token).await;
assert!(result.is_err());
match result {
Err(AuthError::InvalidToken(_)) => {}
_ => panic!("Expected InvalidToken error"),
}
}
#[tokio::test]
async fn test_expired_token() {
let secret = "test-secret";
let config = AuthConfig::with_secret(secret);
let middleware = AuthMiddleware::new(config);
let claims = create_test_claims(true); let token = create_test_token(&claims, secret);
let result = middleware.validate_token_async(&token).await;
assert!(result.is_err());
match result {
Err(AuthError::TokenExpired) => {}
_ => panic!("Expected TokenExpired error"),
}
}
#[tokio::test]
async fn test_tampered_token() {
let secret = "test-secret";
let config = AuthConfig::with_secret(secret);
let middleware = AuthMiddleware::new(config);
let claims = create_test_claims(false);
let mut token = create_test_token(&claims, secret);
if let Some(last_char) = token.pop() {
let replacement = if last_char == 'a' { 'b' } else { 'a' };
token.push(replacement);
}
let result = middleware.validate_token_async(&token).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_dev_mode_skips_signature() {
let config = AuthConfig::dev_mode().expect("dev_mode outside production");
let middleware = AuthMiddleware::new(config);
let claims = create_test_claims(false);
let token = create_test_token(&claims, "any-secret");
let result = middleware.validate_token_async(&token).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_dev_mode_still_checks_expiration() {
let config = AuthConfig::dev_mode().expect("dev_mode outside production");
let middleware = AuthMiddleware::new(config);
let claims = create_test_claims(true); let token = create_test_token(&claims, "any-secret");
let result = middleware.validate_token_async(&token).await;
assert!(result.is_err());
match result {
Err(AuthError::TokenExpired) => {}
_ => panic!("Expected TokenExpired error even in dev mode"),
}
}
#[tokio::test]
async fn test_invalid_token_format() {
let config = AuthConfig::with_secret("secret");
let middleware = AuthMiddleware::new(config);
let result = middleware.validate_token_async("not-a-valid-jwt").await;
assert!(result.is_err());
match result {
Err(AuthError::InvalidToken(_)) => {}
_ => panic!("Expected InvalidToken error"),
}
}
#[test]
fn test_algorithm_conversion() {
assert_eq!(Algorithm::from(JwtAlgorithm::HS256), Algorithm::HS256);
assert_eq!(Algorithm::from(JwtAlgorithm::RS256), Algorithm::RS256);
}
#[test]
fn test_is_hmac_and_is_rsa() {
let hmac_config = AuthConfig::with_secret("test");
assert!(hmac_config.is_hmac());
assert!(!hmac_config.is_rsa());
let rsa_config = AuthConfig {
algorithm: JwtAlgorithm::RS256,
..Default::default()
};
assert!(!rsa_config.is_hmac());
assert!(rsa_config.is_rsa());
}
#[test]
fn test_extract_token_rejects_non_bearer_header() {
let req = Request::builder()
.header(axum::http::header::AUTHORIZATION, "Basic abc")
.body(Body::empty())
.unwrap();
let result = extract_token(&req);
assert!(matches!(result, Err(AuthError::InvalidHeader)));
}
#[test]
fn test_build_auth_context_from_non_uuid_claims_preserves_subject() {
let claims = Claims::builder()
.subject("clerk_user_123")
.role("member")
.claim("tenant_id", serde_json::json!("tenant-1"))
.unwrap()
.build()
.unwrap();
let auth = build_auth_context_from_claims(claims);
assert!(auth.is_authenticated());
assert!(auth.user_id().is_none());
assert_eq!(auth.subject(), Some("clerk_user_123"));
assert_eq!(auth.principal_id(), Some("clerk_user_123".to_string()));
assert!(auth.has_role("member"));
assert_eq!(
auth.claim("sub"),
Some(&serde_json::json!("clerk_user_123"))
);
}
#[cfg(feature = "mcp-oauth")]
#[test]
fn test_verify_session_cookie_round_trip_and_tamper_detection() {
let ip = Some("192.168.1.42");
let ua = Some("TestAgent/1.0");
let cookie = sign_session_cookie("user-123", "session-secret", 86400, ip, ua);
assert_eq!(
verify_session_cookie(&cookie, "session-secret", ip, ua),
Some("user-123".to_string())
);
let mut tampered = cookie.clone();
if let Some(last_char) = tampered.pop() {
tampered.push(if last_char == 'a' { 'b' } else { 'a' });
}
assert_eq!(
verify_session_cookie(&tampered, "session-secret", ip, ua),
None
);
assert_eq!(verify_session_cookie(&cookie, "wrong-secret", ip, ua), None);
}
#[cfg(feature = "mcp-oauth")]
#[test]
fn test_verify_session_cookie_rejects_expired_cookie() {
let expired_cookie = session_cookie_with_expiry(
"user-123",
"session-secret",
chrono::Utc::now().timestamp() - 1,
);
assert_eq!(
verify_session_cookie(
&expired_cookie,
"session-secret",
Some("192.168.1.42"),
Some("TestAgent/1.0"),
),
None
);
}
#[cfg(feature = "mcp-oauth")]
#[test]
fn test_verify_session_cookie_rejects_binding_mismatch() {
let ip = Some("192.168.1.42");
let ua = Some("TestAgent/1.0");
let cookie = sign_session_cookie("user-123", "session-secret", 86400, ip, ua);
assert_eq!(
verify_session_cookie(&cookie, "session-secret", Some("10.0.0.1"), ua),
None
);
assert_eq!(
verify_session_cookie(&cookie, "session-secret", ip, Some("OtherBrowser/2.0")),
None
);
assert_eq!(
verify_session_cookie(&cookie, "session-secret", None, None),
None
);
}
/// A subject containing '.' characters must come back unchanged after signing and
/// verifying a session cookie.
#[cfg(feature = "mcp-oauth")]
#[test]
fn test_session_cookie_round_trips_subject_with_dots() {
    let subject = "clerk.user.abc.123";
    let (ip, ua) = (Some("192.168.1.42"), Some("TestAgent/1.0"));

    let cookie = sign_session_cookie(subject, "session-secret", 86400, ip, ua);
    let recovered = verify_session_cookie(&cookie, "session-secret", ip, ua);

    assert_eq!(recovered, Some(subject.to_owned()));
}
/// A malformed bearer token must surface as `AuthError::InvalidToken` rather than
/// silently producing an unauthenticated context.
#[tokio::test]
async fn test_extract_auth_context_async_invalid_token_errors() {
    let config = AuthConfig::with_secret("secret");
    let middleware = AuthMiddleware::new(config);

    let malformed = Some(String::from("bad.token"));
    let outcome = extract_auth_context_async(malformed, &middleware).await;

    assert!(matches!(outcome, Err(AuthError::InvalidToken(_))));
}
/// Builds a `LegacySecret` for `secret` whose `valid_until` is `valid_for` from now.
/// Pass a negative duration to build an already-expired entry.
fn legacy_secret(secret: &str, valid_for: chrono::Duration) -> forge_core::config::LegacySecret {
    let valid_until = chrono::Utc::now() + valid_for;
    forge_core::config::LegacySecret {
        secret: secret.into(),
        valid_until,
    }
}
/// During a key rotation, tokens signed with a still-valid legacy secret and tokens
/// signed with the new active secret must both be accepted.
#[tokio::test]
async fn test_legacy_secret_validates_token_signed_by_old_key() {
    let old_secret = "old-secret-key-32-bytes-minimum!!";
    let new_secret = "new-secret-key-32-bytes-minimum!!";
    let middleware = AuthMiddleware::new(AuthConfig {
        jwt_secret: Some(new_secret.into()),
        legacy_secrets: vec![legacy_secret(old_secret, chrono::Duration::hours(1))],
        ..AuthConfig::with_secret(new_secret)
    });
    let claims = create_test_claims(false);

    // Signed with the legacy secret.
    let token_from_old_key = create_test_token(&claims, old_secret);
    let result = middleware.validate_token_async(&token_from_old_key).await;
    assert!(
        result.is_ok(),
        "legacy-signed token should be accepted: {result:?}"
    );

    // Signed with the active secret.
    let token_from_new_key = create_test_token(&claims, new_secret);
    let active_accepted = middleware
        .validate_token_async(&token_from_new_key)
        .await
        .is_ok();
    assert!(active_accepted);
}
/// Configuring legacy secrets must not widen acceptance to arbitrary keys: a token
/// signed by a secret that is neither active nor legacy is still rejected.
#[tokio::test]
async fn test_legacy_secret_still_rejects_unknown_key() {
    let known_legacy = legacy_secret(
        "known-legacy-secret-32bytes-pad!!",
        chrono::Duration::hours(1),
    );
    let middleware = AuthMiddleware::new(AuthConfig {
        legacy_secrets: vec![known_legacy],
        ..AuthConfig::with_secret("active-secret-key-32-bytes-pad!!")
    });

    let foreign_token = create_test_token(
        &create_test_claims(false),
        "totally-unknown-secret-32bytes!!",
    );
    let result = middleware.validate_token_async(&foreign_token).await;

    assert!(matches!(result, Err(AuthError::InvalidToken(_))));
}
/// A legacy secret whose `valid_until` has already passed is discarded when the
/// middleware is built, so tokens signed with it no longer validate.
#[tokio::test]
async fn test_expired_legacy_secret_is_dropped_at_construction() {
    let active_secret = "active-secret-key-32-bytes-pad!!";
    let retired_secret = "retired-secret-key-32-bytes-pad!!";

    // valid_until is one second in the past.
    let already_expired = legacy_secret(retired_secret, -chrono::Duration::seconds(1));
    let middleware = AuthMiddleware::new(AuthConfig {
        legacy_secrets: vec![already_expired],
        ..AuthConfig::with_secret(active_secret)
    });

    // Construction itself should have dropped the expired entry...
    assert!(
        middleware.legacy_hmac_keys.is_empty(),
        "expired legacy secret should be dropped at construction"
    );

    // ...so a token signed with it must be rejected.
    let token = create_test_token(&create_test_claims(false), retired_secret);
    let result = middleware.validate_token_async(&token).await;
    assert!(
        matches!(result, Err(AuthError::InvalidToken(_))),
        "expired legacy-signed token must not validate, got: {result:?}"
    );
}
/// Hand-assembles a three-segment JWT string whose header declares `alg`.
///
/// The payload has `sub = "test"`, `exp` one hour from now and `iat = 0`. The
/// signature segment is the base64url encoding of the literal bytes `fakesignature`,
/// so the token is never validly signed.
fn raw_jwt_with_alg(alg: &str) -> String {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    let exp = chrono::Utc::now().timestamp() + 3600;
    [
        URL_SAFE_NO_PAD.encode(format!(r#"{{"alg":"{alg}","typ":"JWT"}}"#)),
        URL_SAFE_NO_PAD.encode(format!(r#"{{"sub":"test","exp":{exp},"iat":0}}"#)),
        URL_SAFE_NO_PAD.encode(b"fakesignature"),
    ]
    .join(".")
}
/// With a shared-secret config, a token whose header claims RS256 must be rejected
/// by the algorithm pre-check with a "does not match" message.
#[tokio::test]
async fn g3_10_jwt_algorithm_pre_check() {
    let config = AuthConfig::with_secret("test-secret-key-32-bytes-minimum!!");
    let middleware = AuthMiddleware::new(config);
    let token = raw_jwt_with_alg("RS256");

    match middleware.validate_token_async(&token).await {
        Err(AuthError::InvalidToken(msg)) => assert!(
            msg.contains("does not match"),
            "expected alg-mismatch message, got: {msg}"
        ),
        other => panic!("expected InvalidToken from alg pre-check, got: {other:?}"),
    }
}
/// A token declaring `alg: none` must be rejected as `InvalidToken`.
#[tokio::test]
async fn g1_jwt_alg_none_rejected() {
    let config = AuthConfig::with_secret("test-secret-key-32-bytes-minimum!!");
    let middleware = AuthMiddleware::new(config);

    let result = middleware
        .validate_token_async(&raw_jwt_with_alg("none"))
        .await;

    assert!(
        matches!(result, Err(AuthError::InvalidToken(_))),
        "alg=none token must be rejected, got: {result:?}"
    );
}
/// `secret_kid` must have these properties:
/// - it is stable for the same secret;
/// - it is 8 lowercase hex characters;
/// - it differs between secrets;
/// - it equals the first four bytes of the secret's SHA-256 digest.
#[test]
fn test_secret_kid_is_deterministic() {
    let kid_a = secret_kid(b"some-secret");
    let kid_b = secret_kid(b"some-secret");
    assert_eq!(kid_a, kid_b);
    assert_eq!(kid_a.len(), 8, "kid should be 8 hex chars (4 bytes)");
    assert!(
        kid_a.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
        "kid should be lowercase hex, got: {kid_a}"
    );
    assert_ne!(kid_a, secret_kid(b"different-secret"));

    // Known-answer vectors from FIPS 180-4. SHA-256("") starts e3b0c442 and
    // SHA-256("abc") starts ba7816bf. Pinning them proves the kid really is the
    // SHA-256 prefix, not just some deterministic 8-character string.
    assert_eq!(secret_kid(b""), "e3b0c442");
    assert_eq!(secret_kid(b"abc"), "ba7816bf");
}
/// Tokens minted by `HmacTokenIssuer` must carry a `kid` header equal to
/// `secret_kid` of the signing secret.
///
/// Signing and header decoding are synchronous here (nothing is awaited), so a plain
/// `#[test]` is used instead of spinning up a tokio runtime.
#[test]
fn test_issued_token_carries_kid_header() {
    use forge_core::TokenIssuer;
    let secret = "issuer-secret-key-32-bytes-pad!!!";
    let config = AuthConfig::with_secret(secret);
    let issuer = HmacTokenIssuer::from_config(&config).expect("issuer for hmac");
    let claims = create_test_claims(false);
    let token = issuer.sign(&claims).expect("signed token");
    let header = jsonwebtoken::decode_header(&token).expect("decodable header");
    assert_eq!(
        header.kid.as_deref(),
        Some(secret_kid(secret.as_bytes()).as_str()),
        "kid in header must match SHA-256 prefix of the secret"
    );
}
/// A token signed with a still-valid legacy secret and tagged with that secret's kid
/// must validate.
#[tokio::test]
async fn test_kid_matched_legacy_token_validates() {
    let active_secret = "active-secret-key-32-bytes-pad!!!";
    let retired_secret = "legacy-secret-key-32-bytes-pad!!!";
    let middleware = AuthMiddleware::new(AuthConfig {
        legacy_secrets: vec![legacy_secret(retired_secret, chrono::Duration::hours(1))],
        ..AuthConfig::with_secret(active_secret)
    });

    // Header carries the kid derived from the retired secret.
    let mut header = Header::new(Algorithm::HS256);
    header.kid = Some(secret_kid(retired_secret.as_bytes()));

    let signing_key = EncodingKey::from_secret(retired_secret.as_bytes());
    let token = encode(&header, &create_test_claims(false), &signing_key)
        .expect("encode legacy-signed token");

    let result = middleware.validate_token_async(&token).await;
    assert!(
        result.is_ok(),
        "kid-tagged legacy token must validate: {result:?}"
    );
}
/// Tokens minted outside this crate may omit `kid`. A correctly signed kidless HMAC
/// token must still validate.
#[tokio::test]
async fn test_external_token_without_kid_still_validates() {
    let secret = "shared-hmac-secret-32-bytes-pad!!";
    let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));
    let token = create_test_token(&create_test_claims(false), secret);

    // Guard the premise: the fixture token really has no kid header.
    let fixture_kid = jsonwebtoken::decode_header(&token)
        .expect("decodable header")
        .kid;
    assert!(fixture_kid.is_none(), "test fixture must not set kid");

    let result = middleware.validate_token_async(&token).await;
    assert!(
        result.is_ok(),
        "kidless external token must validate: {result:?}"
    );
}
/// A token signed with the active secret but carrying an unrecognised kid must
/// still validate, because the middleware falls back to trying every key.
#[tokio::test]
async fn test_unknown_kid_falls_back_to_full_scan() {
    let secret = "active-secret-key-32-bytes-pad!!!";
    let middleware = AuthMiddleware::new(AuthConfig::with_secret(secret));

    // Arbitrary kid that is not expected to match any configured key.
    let mut header = Header::new(Algorithm::HS256);
    header.kid = Some(String::from("deadbeef"));

    let signing_key = EncodingKey::from_secret(secret.as_bytes());
    let token = encode(&header, &create_test_claims(false), &signing_key)
        .expect("encode token with unknown kid");

    let result = middleware.validate_token_async(&token).await;
    assert!(
        result.is_ok(),
        "unknown-kid token must still validate via fallback: {result:?}"
    );
}
/// With `jwks_require_kid` enabled, an RS256 token whose header has no `kid` must be
/// rejected with an `InvalidToken` error that mentions the missing kid.
#[tokio::test]
async fn rsa_token_without_kid_rejected_when_jwks_require_kid_is_true() {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
    let config = AuthConfig {
        algorithm: JwtAlgorithm::RS256,
        // `.invalid` is a reserved TLD that never resolves. The test expects the
        // kid check to reject the token rather than a JWKS fetch failing.
        jwks_client: Some(Arc::new(
            JwksClient::new("http://example.invalid".into(), 3600)
                .expect("construct JwksClient"),
        )),
        jwks_require_kid: true,
        ..AuthConfig::default()
    };
    let middleware = AuthMiddleware::new(config);

    // Hand-assemble a kidless RS256 token with a dummy signature, so no RSA key is needed.
    let header_json = r#"{"alg":"RS256","typ":"JWT"}"#;
    let claims = create_test_claims(false);
    let claims_json = serde_json::to_string(&claims).expect("serialize test claims");
    let header_b64 = URL_SAFE_NO_PAD.encode(header_json.as_bytes());
    let claims_b64 = URL_SAFE_NO_PAD.encode(claims_json.as_bytes());
    let token = format!("{header_b64}.{claims_b64}.fake-signature");

    // One exhaustive match replaces the is_err()/unwrap_err()/match sequence.
    match middleware.validate_token_async(&token).await {
        Err(AuthError::InvalidToken(msg)) => assert!(
            msg.contains("missing kid"),
            "error should mention missing kid, got: {msg}"
        ),
        Ok(validated) => panic!("kidless RS256 token must be rejected, got: {validated:?}"),
        Err(other) => panic!("expected InvalidToken, got: {other:?}"),
    }
}
}