use std::sync::OnceLock;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use crate::extractors::Tenant;
use crate::tenancy::jwt_lifecycle::JwtLifecycle;
/// Settings for [`jwt_router`]: route prefix, token lifetimes and signing secret.
///
/// `Debug` is implemented by hand so the signing secret is never printed.
#[derive(Clone)]
pub struct Config {
    /// Path prefix under which `/login`, `/refresh`, `/logout` and `/me` are mounted.
    pub prefix: String,
    /// Lifetime of issued access tokens, in seconds.
    pub access_ttl_secs: i64,
    /// Lifetime of issued refresh tokens, in seconds.
    pub refresh_ttl_secs: i64,
    /// Explicit secret for the JWT handle; when `None`, the
    /// `RUSTANGO_SESSION_SECRET` environment variable is consulted instead.
    pub session_secret: Option<Vec<u8>>,
}

impl Default for Config {
    /// `/api/auth` prefix, 15-minute access tokens, 7-day refresh tokens,
    /// no explicit secret.
    fn default() -> Self {
        Self {
            prefix: "/api/auth".to_owned(),
            access_ttl_secs: 900,
            refresh_ttl_secs: 7 * 86400,
            session_secret: None,
        }
    }
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("prefix", &self.prefix)
            .field("access_ttl_secs", &self.access_ttl_secs)
            .field("refresh_ttl_secs", &self.refresh_ttl_secs)
            // Only reveal whether a secret is configured, never its bytes.
            .field(
                "session_secret",
                &self.session_secret.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
impl Config {
    /// Builds the [`JwtLifecycle`] handle described by this config.
    ///
    /// The secret is `session_secret` when present. Otherwise it is the bytes of
    /// the `RUSTANGO_SESSION_SECRET` environment variable, or an empty byte
    /// vector if that variable is unset or not valid Unicode.
    ///
    /// NOTE(review): if `JwtLifecycle::new` accepts an empty secret, tokens
    /// signed with it may be forgeable — confirm, and consider failing loudly
    /// when no secret is configured.
    fn build_jwt(&self) -> JwtLifecycle {
        let secret = match &self.session_secret {
            Some(explicit) => explicit.clone(),
            None => std::env::var("RUSTANGO_SESSION_SECRET")
                .map(String::into_bytes)
                .unwrap_or_default(),
        };
        JwtLifecycle::new(secret)
            .with_access_ttl(self.access_ttl_secs)
            .with_refresh_ttl(self.refresh_ttl_secs)
    }

    /// Overrides the token lifetimes with any values set in `s`.
    ///
    /// Fields left as `None` keep their current value; values too large for
    /// `i64` saturate to `i64::MAX`.
    #[cfg(feature = "config")]
    #[must_use]
    pub fn with_jwt_settings(mut self, s: &crate::config::JwtSettings) -> Self {
        if let Some(secs) = s
            .access_ttl_secs
            .map(|raw| i64::try_from(raw).unwrap_or(i64::MAX))
        {
            self.access_ttl_secs = secs;
        }
        if let Some(secs) = s
            .refresh_ttl_secs
            .map(|raw| i64::try_from(raw).unwrap_or(i64::MAX))
        {
            self.refresh_ttl_secs = secs;
        }
        self
    }
}
/// Process-wide JWT handle shared by every handler in this module.
///
/// Filled by the first [`jwt_router`] call, or lazily with the default config
/// by [`jwt_handle`] if a token operation happens before any router is built.
static JWT: OnceLock<JwtLifecycle> = OnceLock::new();

/// Builds the auth router: `POST {prefix}/login`, `POST {prefix}/refresh`,
/// `POST {prefix}/logout` and `GET {prefix}/me`.
///
/// The secret/TTLs from `cfg` are installed only if the global handle is still
/// empty. On a second call — or after [`jwt_handle`] already fell back to
/// [`Config::default`] — only `cfg.prefix` takes effect.
pub fn jwt_router(cfg: Config) -> Router<()> {
    // First initialisation wins; a rejected `set` is deliberately ignored.
    // NOTE(review): consider surfacing this, since a misordered startup
    // silently discards the configured secret.
    let _ = JWT.set(cfg.build_jwt());
    let path = |endpoint: &str| format!("{}/{endpoint}", cfg.prefix);
    Router::new()
        .route(&path("login"), post(login))
        .route(&path("refresh"), post(refresh))
        .route(&path("logout"), post(logout))
        .route(&path("me"), get(me))
}

/// Returns the process-wide JWT handle, initialising it from
/// [`Config::default`] if no router has installed one yet.
fn jwt_handle() -> &'static JwtLifecycle {
    JWT.get_or_init(|| {
        let fallback = Config::default();
        fallback.build_jwt()
    })
}
/// JSON body of `POST {prefix}/login`.
///
/// `Debug` is implemented by hand so the plaintext password never reaches
/// logs or panic messages through `{:?}`.
#[derive(Deserialize)]
pub struct LoginInput {
    pub username: String,
    pub password: String,
}

impl std::fmt::Debug for LoginInput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginInput")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Minimal user profile returned by the `login` and `me` endpoints.
// Field order is kept as-is: it determines the serialized JSON key order.
#[derive(Serialize, Debug)]
pub struct UserBrief {
    /// User primary key; the same value used as the token subject.
    pub user_id: i64,
    /// Login name.
    pub username: String,
    /// Whether the user has superuser rights.
    pub is_superuser: bool,
}
/// JSON response of a successful `POST {prefix}/login`.
///
/// `Debug` is implemented by hand so the issued tokens — bearer credentials —
/// are never printed.
#[derive(Serialize)]
pub struct LoginOutput {
    /// Access token.
    pub access: String,
    /// Refresh token.
    pub refresh: String,
    /// Profile of the authenticated user.
    pub user: UserBrief,
}

impl std::fmt::Debug for LoginOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginOutput")
            .field("access", &"<redacted>")
            .field("refresh", &"<redacted>")
            .field("user", &self.user)
            .finish()
    }
}
/// `POST {prefix}/login` — checks `username`/`password` against the tenant's
/// user table and issues an access/refresh pair whose custom `"tenant"` claim
/// is the tenant slug.
///
/// # Errors
/// - `401` "invalid credentials" for an unknown username or a wrong password.
/// - `403` "account inactive" when the password is correct but the account is
///   disabled. The password is verified first, so account existence/status is
///   not revealed to callers who do not know the password.
/// - `500` on database, password-verification or token-issuing failures, or if
///   the fetched user row has no id.
async fn login(mut t: Tenant, Json(body): Json<LoginInput>) -> Result<Json<LoginOutput>, Response> {
    use crate::core::Column as _;
    use crate::tenancy::auth::User;

    let users = User::objects()
        .where_(User::username.eq(body.username))
        .fetch_on(t.conn())
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?;
    // NOTE(review): the unknown-user path skips password verification, so
    // response timing may still hint whether a username exists; verifying
    // against a dummy hash would close that gap.
    let user = users
        .into_iter()
        .next()
        .ok_or_else(|| (StatusCode::UNAUTHORIZED, "invalid credentials").into_response())?;

    let ok = crate::tenancy::password::verify(&body.password, &user.password_hash)
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?;
    if !ok {
        return Err((StatusCode::UNAUTHORIZED, "invalid credentials").into_response());
    }
    // Only disclose the account state once the caller has proven the password.
    if !user.active {
        return Err((StatusCode::FORBIDDEN, "account inactive").into_response());
    }

    // A fetched row should always carry its id; never mint a token for a
    // placeholder subject if it does not.
    let user_id = user.id.get().copied().ok_or_else(|| {
        (StatusCode::INTERNAL_SERVER_ERROR, "user record has no id").into_response()
    })?;
    let custom = serde_json::json!({"tenant": t.org.slug});
    let pair = jwt_handle()
        .issue_pair_with(user_id, custom.as_object().cloned().unwrap_or_default())
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?;
    Ok(Json(LoginOutput {
        access: pair.access,
        refresh: pair.refresh,
        user: UserBrief {
            user_id,
            username: user.username,
            is_superuser: user.is_superuser,
        },
    }))
}
/// JSON body of `POST {prefix}/refresh`.
///
/// `Debug` is implemented by hand so the refresh token — a bearer credential —
/// is never printed.
#[derive(Deserialize)]
pub struct RefreshInput {
    /// Refresh token previously issued by `login` or `refresh`.
    pub refresh: String,
}

impl std::fmt::Debug for RefreshInput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshInput")
            .field("refresh", &"<redacted>")
            .finish()
    }
}
/// JSON response of a successful `POST {prefix}/refresh`.
///
/// `Debug` is implemented by hand so the issued tokens are never printed.
#[derive(Serialize)]
pub struct RefreshOutput {
    /// New access token.
    pub access: String,
    /// New refresh token.
    pub refresh: String,
}

impl std::fmt::Debug for RefreshOutput {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshOutput")
            .field("access", &"<redacted>")
            .field("refresh", &"<redacted>")
            .finish()
    }
}
/// `POST {prefix}/refresh` — exchanges a refresh token for a new token pair.
///
/// Responds `401` when the JWT handle rejects the refresh token.
///
/// NOTE(review): this endpoint is not tenant-scoped; whether the `"tenant"`
/// claim is carried into the new pair depends on `JwtLifecycle::refresh` —
/// confirm, otherwise refreshed tokens will fail `verify_for_tenant`.
async fn refresh(Json(body): Json<RefreshInput>) -> Result<Json<RefreshOutput>, Response> {
    let Some(pair) = jwt_handle().refresh(&body.refresh) else {
        return Err(
            (StatusCode::UNAUTHORIZED, "invalid or expired refresh token").into_response(),
        );
    };
    let response = RefreshOutput {
        access: pair.access,
        refresh: pair.refresh,
    };
    Ok(Json(response))
}
async fn logout(bearer: Bearer) -> Result<StatusCode, Response> {
jwt_handle().revoke(&bearer.0);
Ok(StatusCode::NO_CONTENT)
}
/// Verifies an access token and checks that it was issued for `expected_slug`.
///
/// Returns the token subject (the user id) on success.
///
/// # Errors
/// A static message when the token fails verification, carries no string
/// `"tenant"` claim, or names a different tenant.
pub fn verify_for_tenant(bearer: &str, expected_slug: &str) -> Result<i64, &'static str> {
    let Some(claims) = jwt_handle().verify_access(bearer) else {
        return Err("invalid or expired token");
    };
    match claims.custom_value("tenant").and_then(|v| v.as_str()) {
        None => Err("token missing tenant binding"),
        Some(slug) if slug == expected_slug => Ok(claims.sub),
        Some(_) => Err("token issued for different tenant"),
    }
}
/// `GET {prefix}/me` — returns the profile of the user named by the bearer
/// access token, provided the token is bound to the current tenant.
///
/// # Errors
/// - `401` when the token is invalid, bound to another tenant, or its subject
///   no longer exists.
/// - `403` when the account is inactive.
/// - `500` on database errors.
async fn me(mut t: Tenant, bearer: Bearer) -> Result<Json<UserBrief>, Response> {
    use crate::core::Column as _;
    use crate::tenancy::auth::User;

    let user_id = verify_for_tenant(&bearer.0, &t.org.slug)
        .map_err(|msg| (StatusCode::UNAUTHORIZED, msg).into_response())?;
    let users = User::objects()
        .where_(User::id.eq(user_id))
        .fetch_on(t.conn())
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?;
    let user = users
        .into_iter()
        .next()
        .ok_or_else(|| (StatusCode::UNAUTHORIZED, "user not found").into_response())?;
    if !user.active {
        return Err((StatusCode::FORBIDDEN, "account inactive").into_response());
    }
    // The row was selected by `id == user_id`, so report the verified id
    // directly instead of re-reading it with a `0` sentinel fallback.
    Ok(Json(UserBrief {
        user_id,
        username: user.username,
        is_superuser: user.is_superuser,
    }))
}
/// Extractor for the token in an `Authorization: Bearer <token>` header.
///
/// Rejects with `401` "missing Bearer token" when the header is absent, is not
/// valid text, uses another scheme, or carries an empty token.
struct Bearer(String);

impl<S: Send + Sync> FromRequestParts<S> for Bearer {
    type Rejection = Response;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        parts
            .headers
            .get(axum::http::header::AUTHORIZATION)
            .and_then(|h| h.to_str().ok())
            .and_then(parse_bearer)
            .map(|token| Bearer(token.to_owned()))
            .ok_or_else(|| (StatusCode::UNAUTHORIZED, "missing Bearer token").into_response())
    }
}

/// Splits an `Authorization` header value into scheme and credentials and
/// returns the trimmed token when the scheme is `Bearer`.
///
/// The scheme is matched case-insensitively, as HTTP authentication schemes
/// are case-insensitive. An empty token yields `None`.
fn parse_bearer(value: &str) -> Option<&str> {
    let (scheme, rest) = value.split_once(' ')?;
    let token = rest.trim();
    (scheme.eq_ignore_ascii_case("Bearer") && !token.is_empty()).then_some(token)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_default_paths_match_documentation() {
        let cfg = Config::default();
        assert_eq!(cfg.prefix, "/api/auth");
        assert_eq!(cfg.access_ttl_secs, 900);
        assert_eq!(cfg.refresh_ttl_secs, 7 * 86400);
        assert!(cfg.session_secret.is_none());
    }

    /// Smoke test: building with the environment fallback must not panic,
    /// whether or not `RUSTANGO_SESSION_SECRET` is set.
    #[test]
    fn config_falls_back_to_env_secret_when_unset() {
        let _ = Config::default().build_jwt();
    }

    #[test]
    fn config_uses_explicit_secret_when_set() {
        let cfg = Config {
            session_secret: Some(b"super-secret-key-for-tests".to_vec()),
            ..Default::default()
        };
        let jwt = cfg.build_jwt();
        let token = jwt.issue_pair(42);
        let claims = jwt.verify_access(&token.access).expect("access valid");
        assert_eq!(claims.sub, 42);
    }

    /// Smoke test only: it checks that the router builds, not that each path
    /// is routable. It also installs the process-wide JWT handle (first call
    /// wins).
    #[test]
    fn jwt_router_mounts_all_four_endpoints() {
        let _router = jwt_router(Config::default());
    }

    #[cfg(feature = "config")]
    #[test]
    fn with_jwt_settings_overrides_ttls() {
        let mut s = crate::config::JwtSettings::default();
        s.access_ttl_secs = Some(60);
        s.refresh_ttl_secs = Some(3600);
        let cfg = Config::default().with_jwt_settings(&s);
        assert_eq!(cfg.access_ttl_secs, 60);
        assert_eq!(cfg.refresh_ttl_secs, 3600);
    }

    #[cfg(feature = "config")]
    #[test]
    fn with_jwt_settings_unset_preserves_defaults() {
        let s = crate::config::JwtSettings::default();
        let cfg = Config::default().with_jwt_settings(&s);
        assert_eq!(cfg.access_ttl_secs, 900);
        assert_eq!(cfg.refresh_ttl_secs, 7 * 86400);
    }
}