use std::sync::OnceLock;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use crate::extractors::Tenant;
use crate::tenancy::jwt_lifecycle::JwtLifecycle;
/// Configuration for the JWT auth router built by [`jwt_router`].
///
/// `Debug` is implemented by hand so that the signing key held in
/// `session_secret` is never written to logs; only its presence is shown.
#[derive(Clone)]
pub struct Config {
    /// Path prefix under which `/login`, `/refresh`, `/logout` and `/me` are mounted.
    pub prefix: String,
    /// Lifetime of issued access tokens, in seconds.
    pub access_ttl_secs: i64,
    /// Lifetime of issued refresh tokens, in seconds.
    pub refresh_ttl_secs: i64,
    /// Explicit JWT signing key. When `None`, the key is taken from the
    /// `RUSTANGO_SESSION_SECRET` environment variable instead.
    pub session_secret: Option<Vec<u8>>,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("prefix", &self.prefix)
            .field("access_ttl_secs", &self.access_ttl_secs)
            .field("refresh_ttl_secs", &self.refresh_ttl_secs)
            // Reveal only whether a key is configured, never its bytes: a
            // leaked HMAC key lets anyone forge tokens.
            .field(
                "session_secret",
                &self.session_secret.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

impl Default for Config {
    /// Mounts under `/api/auth` with 15-minute access tokens and 7-day refresh
    /// tokens; the signing key is resolved from the environment.
    fn default() -> Self {
        Self {
            prefix: "/api/auth".to_owned(),
            access_ttl_secs: 900,
            refresh_ttl_secs: 7 * 86400,
            session_secret: None,
        }
    }
}
/// Minimum accepted length, in bytes, of the HMAC key used to sign JWTs.
const MIN_HMAC_KEY_LEN: usize = 32;

impl Config {
    /// Builds the [`JwtLifecycle`] described by this configuration.
    ///
    /// Key resolution: `session_secret` if set; otherwise the bytes of the
    /// `RUSTANGO_SESSION_SECRET` environment variable, used verbatim (they are
    /// not base64-decoded). An unset or non-UTF-8 variable resolves to an empty
    /// key.
    ///
    /// # Panics
    ///
    /// Panics when the resolved key is shorter than [`MIN_HMAC_KEY_LEN`] bytes,
    /// so the process refuses to start with a guessable signing key.
    fn build_jwt(&self) -> JwtLifecycle {
        let key: Vec<u8> = match self.session_secret.as_ref() {
            Some(explicit) => explicit.clone(),
            None => std::env::var("RUSTANGO_SESSION_SECRET")
                .map(String::into_bytes)
                .unwrap_or_default(),
        };
        let key_len = key.len();
        assert!(
            key_len >= MIN_HMAC_KEY_LEN,
            "JWT signing key is {} bytes; need >= {MIN_HMAC_KEY_LEN}. Set \
             RUSTANGO_SESSION_SECRET to a base64-encoded 32+ byte value \
             (e.g. `openssl rand -base64 32`) or pass an explicit \
             auth_routes::Config::session_secret. Refusing to start with a \
             guessable key (would allow JWT forgery).",
            key_len,
        );
        JwtLifecycle::new(key)
            .with_access_ttl(self.access_ttl_secs)
            .with_refresh_ttl(self.refresh_ttl_secs)
    }

    /// Overrides the token TTLs with any values present in `s`.
    ///
    /// Fields that are `None` in `s` leave the current TTL untouched. Values
    /// that do not fit in `i64` become `i64::MAX` instead of failing.
    #[cfg(feature = "config")]
    #[must_use]
    pub fn with_jwt_settings(mut self, s: &crate::config::JwtSettings) -> Self {
        if let Some(ttl) = s
            .access_ttl_secs
            .map(|v| i64::try_from(v).unwrap_or(i64::MAX))
        {
            self.access_ttl_secs = ttl;
        }
        if let Some(ttl) = s
            .refresh_ttl_secs
            .map(|v| i64::try_from(v).unwrap_or(i64::MAX))
        {
            self.refresh_ttl_secs = ttl;
        }
        self
    }
}
/// Process-wide JWT issuer/verifier shared by every handler in this module.
///
/// Populated by the first `jwt_router` call, or lazily by `jwt_handle` from
/// `Config::default()` if a handler needs it before any router was built.
static JWT: OnceLock<JwtLifecycle> = OnceLock::new();

/// Builds the JWT auth router: `POST {prefix}/login`, `POST {prefix}/refresh`,
/// `POST {prefix}/logout` and `GET {prefix}/me`.
///
/// The lifecycle built from `cfg` (signing key and TTLs) is installed into the
/// process-wide slot only if that slot is still empty. The first
/// initialisation wins. A later call with a different `cfg` keeps the existing
/// lifecycle, although its `prefix` still applies to the returned routes.
///
/// # Panics
///
/// Panics if `cfg` resolves to a signing key shorter than 32 bytes. The key is
/// validated even when the process-wide lifecycle is already set.
pub fn jwt_router(cfg: Config) -> Router<()> {
    let lifecycle = cfg.build_jwt();
    // `set` only fails when the slot is already populated, in which case the
    // existing lifecycle is kept.
    let _ = JWT.set(lifecycle);
    let endpoint = |suffix: &str| format!("{}{suffix}", cfg.prefix);
    Router::new()
        .route(&endpoint("/login"), post(login))
        .route(&endpoint("/refresh"), post(refresh))
        .route(&endpoint("/logout"), post(logout))
        .route(&endpoint("/me"), get(me))
}
/// Returns the process-wide `JwtLifecycle`.
///
/// If the `JWT` slot is still empty, it is initialised from `Config::default()`,
/// which means the signing key comes from the `RUSTANGO_SESSION_SECRET`
/// environment variable.
///
/// # Panics
///
/// On that lazy path, panics when the environment-provided key is shorter than
/// `MIN_HMAC_KEY_LEN` bytes.
fn jwt_handle() -> &'static JwtLifecycle {
    JWT.get_or_init(|| {
        let fallback = Config::default();
        fallback.build_jwt()
    })
}
/// JSON body accepted by `POST {prefix}/login`.
///
/// `Debug` is implemented by hand so the plaintext password can never reach
/// logs through `{:?}` formatting.
#[derive(Deserialize)]
pub struct LoginInput {
    /// Username to authenticate.
    pub username: String,
    /// Plaintext password, checked against the user's stored hash.
    pub password: String,
}

impl std::fmt::Debug for LoginInput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginInput")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Public profile fragment returned by the login and `/me` endpoints.
#[derive(Debug, Serialize)]
pub struct UserBrief {
    /// The user's primary key (handlers fall back to 0 if the row has no id).
    pub user_id: i64,
    /// The user's login name.
    pub username: String,
    /// Whether the user has superuser rights.
    pub is_superuser: bool,
}

/// JSON response of a successful `POST {prefix}/login`.
///
/// `Debug` is implemented by hand. Both tokens are bearer credentials, so only
/// `user` is printed in full.
#[derive(Serialize)]
pub struct LoginOutput {
    /// Access token, presented as `Authorization: Bearer <token>`.
    pub access: String,
    /// Refresh token, exchangeable at `{prefix}/refresh` for a new pair.
    pub refresh: String,
    /// The authenticated user.
    pub user: UserBrief,
}

impl std::fmt::Debug for LoginOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginOutput")
            .field("access", &"<redacted>")
            .field("refresh", &"<redacted>")
            .field("user", &self.user)
            .finish()
    }
}
/// `POST {prefix}/login` — exchanges a username/password for an access/refresh
/// token pair bound to the current tenant.
///
/// Every authentication failure (unknown user, locked account, inactive
/// account, wrong password) returns the same `401 invalid credentials` response.
/// Each one fires a `user_login_failed` signal carrying the matching
/// `AuthFailureReason`.
///
/// On success:
/// * the lockout counter is cleared (with the `cache` feature),
/// * a `user_logged_in` signal fires,
/// * tokens are issued whose `tenant` claim is the tenant slug.
///
/// # Errors
///
/// * `401` for any authentication failure.
/// * `500` when the user query, password verification or token issuing fails.
async fn login(
    t: Tenant,
    headers: axum::http::HeaderMap,
    Json(body): Json<LoginInput>,
) -> Result<Json<LoginOutput>, Response> {
    use crate::core::Column as _;
    use crate::signals::auth::{
        meta_from_headers, send_user_logged_in, send_user_login_failed, AuthFailureReason,
        UserLoggedInContext, UserLoginFailedContext,
    };
    use crate::sql::FetcherPool as _;
    use crate::tenancy::auth::User;

    let meta = meta_from_headers(&headers, Some("/auth/login"));
    // Builds the payload of a `user_login_failed` signal for `reason`.
    let failed_ctx = |reason: AuthFailureReason| UserLoginFailedContext {
        source: "jwt",
        attempted_username: Some(body.username.clone()),
        reason,
        request: meta.clone(),
    };

    let users = User::objects()
        .where_(User::username.eq(body.username.clone()))
        .fetch_pool(t.pool())
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?;
    let Some(user) = users.into_iter().next() else {
        // Spend hashing time anyway so latency does not reveal that the
        // username is unknown.
        crate::tenancy::password::verify_dummy(&body.password);
        send_user_login_failed(failed_ctx(AuthFailureReason::InvalidCredentials)).await;
        return Err((StatusCode::UNAUTHORIZED, "invalid credentials").into_response());
    };
    let user_id = user.id.get().copied().unwrap_or(0);

    #[cfg(feature = "cache")]
    let lock_key = format!("tenant:{}:{}", t.org.slug, user_id);
    #[cfg(feature = "cache")]
    if crate::account_lockout::shared().is_locked(&lock_key).await {
        // Hash here too. Otherwise a locked account answers noticeably faster
        // than an unknown user or a wrong password, which leaks (by timing)
        // that the username exists and is locked.
        crate::tenancy::password::verify_dummy(&body.password);
        send_user_login_failed(failed_ctx(AuthFailureReason::InvalidCredentials)).await;
        return Err((StatusCode::UNAUTHORIZED, "invalid credentials").into_response());
    }

    let ok = crate::tenancy::password::verify(&body.password, &user.password_hash)
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?;
    // Inactive accounts are rejected whether or not the password matched, and
    // without touching the lockout counter.
    if !user.active {
        send_user_login_failed(failed_ctx(AuthFailureReason::Inactive)).await;
        return Err((StatusCode::UNAUTHORIZED, "invalid credentials").into_response());
    }
    if !ok {
        #[cfg(feature = "cache")]
        {
            // The outcome is ignored so lockout bookkeeping never fails the request.
            let _ = crate::account_lockout::shared()
                .record_failure(&lock_key)
                .await;
        }
        send_user_login_failed(failed_ctx(AuthFailureReason::InvalidCredentials)).await;
        return Err((StatusCode::UNAUTHORIZED, "invalid credentials").into_response());
    }

    #[cfg(feature = "cache")]
    crate::account_lockout::shared().clear(&lock_key).await;
    send_user_logged_in(UserLoggedInContext {
        source: "jwt",
        user_id,
        username: user.username.clone(),
        is_superuser: user.is_superuser,
        request: meta,
    })
    .await;

    // Bind the tokens to this tenant; refresh, logout and /me reject a
    // mismatched `tenant` claim.
    let custom = serde_json::json!({"tenant": t.org.slug});
    let pair = jwt_handle()
        .issue_pair_with(user_id, custom.as_object().cloned().unwrap_or_default())
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?;
    Ok(Json(LoginOutput {
        access: pair.access,
        refresh: pair.refresh,
        user: UserBrief {
            user_id,
            username: user.username,
            is_superuser: user.is_superuser,
        },
    }))
}
/// JSON body accepted by `POST {prefix}/refresh`.
///
/// `Debug` is implemented by hand so the refresh token, a bearer credential,
/// is not exposed through `{:?}` formatting.
#[derive(Deserialize)]
pub struct RefreshInput {
    /// The refresh token to rotate.
    pub refresh: String,
}

impl std::fmt::Debug for RefreshInput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshInput")
            .field("refresh", &"<redacted>")
            .finish()
    }
}

/// JSON response of a successful `POST {prefix}/refresh`.
///
/// `Debug` is implemented by hand so neither token appears in logs.
#[derive(Serialize)]
pub struct RefreshOutput {
    /// Newly issued access token.
    pub access: String,
    /// Newly issued refresh token.
    pub refresh: String,
}

impl std::fmt::Debug for RefreshOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshOutput")
            .field("access", &"<redacted>")
            .field("refresh", &"<redacted>")
            .finish()
    }
}
/// `POST {prefix}/refresh` — exchanges a refresh token for a new access/refresh pair.
///
/// The presented token must:
/// * pass `verify_refresh`,
/// * carry a `tenant` claim equal to the current tenant's slug,
/// * name a user that still exists and is active in this tenant's database.
///
/// NOTE(review): whether the new pair keeps the `tenant` claim depends on
/// `JwtLifecycle::refresh`, which is not visible here — confirm.
///
/// # Errors
///
/// * `401` when the token is invalid or expired, bound to another tenant, or its
///   user is missing or inactive.
/// * `500` when the user lookup fails.
async fn refresh(
    t: Tenant,
    Json(body): Json<RefreshInput>,
) -> Result<Json<RefreshOutput>, Response> {
    use crate::core::Column as _;
    use crate::sql::FetcherPool as _;
    use crate::tenancy::auth::User;

    let rejected =
        || (StatusCode::UNAUTHORIZED, "invalid or expired refresh token").into_response();

    let Some(claims) = jwt_handle().verify_refresh(&body.refresh) else {
        return Err(rejected());
    };
    let bound_tenant = claims.custom_value("tenant").and_then(|v| v.as_str());
    if bound_tenant != Some(t.org.slug.as_str()) {
        return Err((
            StatusCode::UNAUTHORIZED,
            "refresh token issued for a different tenant",
        )
            .into_response());
    }

    // Re-check the account so a deactivated or deleted user cannot keep minting
    // tokens with a refresh token issued before the change.
    let owner_active = User::objects()
        .where_(User::id.eq(claims.sub))
        .fetch_pool(t.pool())
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?
        .into_iter()
        .next()
        .is_some_and(|u| u.active);
    if !owner_active {
        return Err(rejected());
    }

    let pair = jwt_handle().refresh(&body.refresh).ok_or_else(rejected)?;
    Ok(Json(RefreshOutput {
        access: pair.access,
        refresh: pair.refresh,
    }))
}
/// `POST {prefix}/logout` — revokes the presented access token.
///
/// Cases:
/// * A token that verifies but carries a different (or missing) `tenant` claim
///   is rejected with `401`.
/// * A token that does not verify (e.g. already expired) is still passed to
///   `revoke`. The request then succeeds with `204`, and the `user_logged_out`
///   signal reports no user id.
///
/// NOTE(review): only the access token is revoked. The refresh token is not
/// sent to this endpoint, so presumably it stays usable until it expires or is
/// revoked elsewhere — confirm this is intended.
async fn logout(
    t: Tenant,
    headers: axum::http::HeaderMap,
    bearer: Bearer,
) -> Result<StatusCode, Response> {
    use crate::signals::auth::{meta_from_headers, send_user_logged_out, UserLoggedOutContext};

    let token = &bearer.0;
    let jwt = jwt_handle();
    let user_id = match jwt.verify_access(token) {
        Some(claims)
            if claims.custom_value("tenant").and_then(|v| v.as_str())
                != Some(t.org.slug.as_str()) =>
        {
            return Err((
                StatusCode::UNAUTHORIZED,
                "token issued for a different tenant",
            )
                .into_response());
        }
        Some(claims) => Some(claims.sub),
        None => None,
    };

    let meta = meta_from_headers(&headers, Some("/auth/logout"));
    jwt.revoke(token);
    send_user_logged_out(UserLoggedOutContext {
        source: "jwt",
        user_id,
        username: None,
        request: meta,
    })
    .await;
    Ok(StatusCode::NO_CONTENT)
}
/// Verifies an access token and checks that it is bound to `expected_slug`.
///
/// On success, returns the token's subject (the user id).
///
/// # Errors
///
/// Returns a static message (used as a `401` body by `/me`) in three cases:
/// * the token fails verification,
/// * the token has no string `tenant` claim,
/// * the claim names a different tenant.
pub fn verify_for_tenant(bearer: &str, expected_slug: &str) -> Result<i64, &'static str> {
    let Some(claims) = jwt_handle().verify_access(bearer) else {
        return Err("invalid or expired token");
    };
    match claims.custom_value("tenant").and_then(|v| v.as_str()) {
        None => Err("token missing tenant binding"),
        Some(slug) if slug == expected_slug => Ok(claims.sub),
        Some(_) => Err("token issued for different tenant"),
    }
}
/// `GET {prefix}/me` — returns the profile of the user named by the bearer token.
///
/// # Errors
///
/// * `401` when the token is not valid for this tenant, or its user no longer exists.
/// * `403` when the user's account is inactive.
/// * `500` when the user lookup fails.
async fn me(t: Tenant, bearer: Bearer) -> Result<Json<UserBrief>, Response> {
    use crate::core::Column as _;
    use crate::sql::FetcherPool as _;
    use crate::tenancy::auth::User;

    let unauthorized = |msg: &'static str| (StatusCode::UNAUTHORIZED, msg).into_response();

    let user_id = verify_for_tenant(&bearer.0, &t.org.slug).map_err(unauthorized)?;
    let Some(user) = User::objects()
        .where_(User::id.eq(user_id))
        .fetch_pool(t.pool())
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?
        .into_iter()
        .next()
    else {
        return Err(unauthorized("user not found"));
    };

    if !user.active {
        return Err((StatusCode::FORBIDDEN, "account inactive").into_response());
    }
    Ok(Json(UserBrief {
        user_id: user.id.get().copied().unwrap_or(0),
        username: user.username,
        is_superuser: user.is_superuser,
    }))
}
/// Credential extracted from an `Authorization: Bearer <token>` request header.
///
/// The wrapped string is the token with surrounding whitespace removed; it is
/// never empty.
struct Bearer(String);

/// Extracts the token from an `Authorization` header value that uses the
/// `Bearer` scheme.
///
/// Returns `None` if the scheme is different or the token is blank. The scheme
/// name is matched case-insensitively, as HTTP requires for authentication
/// schemes (RFC 9110 §11.1), so `bearer abc` is accepted as well as
/// `Bearer abc`.
fn bearer_token(header_value: &str) -> Option<&str> {
    let (scheme, rest) = header_value.split_once(' ')?;
    let token = rest.trim();
    (scheme.eq_ignore_ascii_case("Bearer") && !token.is_empty()).then_some(token)
}

impl<S: Send + Sync> FromRequestParts<S> for Bearer {
    type Rejection = Response;

    /// Rejects with `401 missing Bearer token` when the header is absent, is
    /// not valid visible ASCII, uses another scheme, or carries an empty token.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        parts
            .headers
            .get(axum::http::header::AUTHORIZATION)
            .and_then(|h| h.to_str().ok())
            .and_then(bearer_token)
            .map(|token| Bearer(token.to_owned()))
            .ok_or_else(|| (StatusCode::UNAUTHORIZED, "missing Bearer token").into_response())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for router configuration and signing-key validation.
    use super::*;

    #[test]
    fn config_default_paths_match_documentation() {
        let cfg = Config::default();
        assert_eq!(cfg.prefix, "/api/auth");
        assert_eq!(cfg.access_ttl_secs, 900);
        assert_eq!(cfg.refresh_ttl_secs, 7 * 86400);
        assert!(cfg.session_secret.is_none());
    }

    #[test]
    #[should_panic(expected = "JWT signing key")]
    fn build_jwt_panics_on_empty_secret() {
        let cfg = Config {
            session_secret: Some(Vec::new()),
            ..Default::default()
        };
        let _ = cfg.build_jwt();
    }

    #[test]
    #[should_panic(expected = "JWT signing key")]
    fn build_jwt_panics_on_short_secret() {
        let cfg = Config {
            session_secret: Some(b"too-short-key".to_vec()),
            ..Default::default()
        };
        let _ = cfg.build_jwt();
    }

    #[test]
    fn config_uses_explicit_secret_when_set() {
        let cfg = Config {
            session_secret: Some(b"super-secret-key-for-tests-32b!!".to_vec()),
            ..Default::default()
        };
        assert!(cfg.session_secret.as_ref().unwrap().len() >= 32);
        let jwt = cfg.build_jwt();
        let pair = jwt.issue_pair(42);
        let claims = jwt.verify_access(&pair.access).expect("access valid");
        assert_eq!(claims.sub, 42);
    }

    #[test]
    fn jwt_router_mounts_all_four_endpoints() {
        let cfg = Config {
            session_secret: Some(b"router-smoke-test-secret-32-byte".to_vec()),
            ..Default::default()
        };
        // Smoke test: building the router must not panic. This also installs
        // the process-wide JWT lifecycle if it is still empty.
        let _router = jwt_router(cfg);
    }

    #[cfg(feature = "config")]
    #[test]
    fn with_jwt_settings_overrides_ttls() {
        let mut s = crate::config::JwtSettings::default();
        s.access_ttl_secs = Some(60);
        s.refresh_ttl_secs = Some(3600);
        let cfg = Config::default().with_jwt_settings(&s);
        assert_eq!(cfg.access_ttl_secs, 60);
        assert_eq!(cfg.refresh_ttl_secs, 3600);
    }

    #[cfg(feature = "config")]
    #[test]
    fn with_jwt_settings_unset_preserves_defaults() {
        let s = crate::config::JwtSettings::default();
        let cfg = Config::default().with_jwt_settings(&s);
        assert_eq!(cfg.access_ttl_secs, 900);
        assert_eq!(cfg.refresh_ttl_secs, 7 * 86400);
    }
}