use std::sync::Arc;
use axum::{
Json,
extract::{Query, State},
http::{StatusCode, header},
response::{IntoResponse, Redirect, Response},
};
use serde::{Deserialize, Serialize};
use crate::auth::{OidcServerClient, PkceStateStore};
/// Shared state for the PKCE login routes ([`auth_start`] and [`auth_callback`]).
pub struct AuthPkceState {
/// Creates PKCE state in `auth_start` and consumes it in `auth_callback`.
pub pkce_store: Arc<PkceStateStore>,
/// Builds the provider authorization URL and performs the code-for-token exchange.
pub oidc_client: Arc<OidcServerClient>,
/// HTTP client handed to the token exchange.
pub http_client: Arc<reqwest::Client>,
/// When `Some`, the callback redirects here and delivers the access token in a
/// `__Host-access_token` cookie; when `None`, the tokens are returned as a JSON body.
pub post_login_redirect_uri: Option<String>,
}
/// Query string accepted by `GET /auth/start`.
#[derive(Deserialize)]
pub struct AuthStartQuery {
/// Caller-supplied URI recorded alongside the new PKCE state. `auth_start` rejects it when
/// empty or longer than 2048 bytes.
redirect_uri: String,
}
/// Query string the OIDC provider sends back to `GET /auth/callback`.
///
/// On success the provider supplies `code` and `state`. On failure it supplies `error`
/// (and optionally `error_description`) instead. Every field is optional so that the handler
/// can report which one is missing.
#[derive(Deserialize)]
pub struct AuthCallbackQuery {
/// Authorization code to exchange for tokens.
code: Option<String>,
/// Opaque state token issued by `auth_start`; used to look up the PKCE verifier.
state: Option<String>,
/// Provider error code (e.g. `access_denied`); takes precedence over `code`/`state`.
error: Option<String>,
/// Free-text provider description; logged only, never returned to the client.
error_description: Option<String>,
}
/// JSON body returned by `auth_callback` when no post-login redirect is configured.
#[derive(Serialize)]
struct TokenJson {
access_token: String,
// Omitted from the JSON when the provider did not return an ID token.
#[serde(skip_serializing_if = "Option::is_none")]
id_token: Option<String>,
// Omitted from the JSON when the provider did not report a lifetime.
#[serde(skip_serializing_if = "Option::is_none")]
expires_in: Option<u64>,
// Set to "Bearer" by `auth_callback`.
token_type: &'static str,
}
fn auth_error(status: StatusCode, message: &str) -> Response {
(status, Json(serde_json::json!({ "error": message }))).into_response()
}
/// Starts the PKCE authorization-code flow.
///
/// Validates the caller-supplied `redirect_uri` (non-empty, at most 2048 bytes) and records a
/// new PKCE state for it. It then redirects the browser to the OIDC provider's authorization
/// endpoint with an S256 code challenge derived from the stored verifier.
///
/// Returns `400` for an invalid `redirect_uri` and `500` if the PKCE state could not be created.
pub async fn auth_start(
    State(state): State<Arc<AuthPkceState>>,
    Query(q): Query<AuthStartQuery>,
) -> Response {
    const MAX_REDIRECT_URI_LEN: usize = 2048;

    if q.redirect_uri.is_empty() {
        return auth_error(StatusCode::BAD_REQUEST, "redirect_uri is required");
    }
    if q.redirect_uri.len() > MAX_REDIRECT_URI_LEN {
        return auth_error(StatusCode::BAD_REQUEST, "redirect_uri exceeds maximum length");
    }

    let created = state.pkce_store.create_state(&q.redirect_uri).await;
    let (outbound_token, verifier) = match created {
        Ok(pair) => pair,
        Err(e) => {
            tracing::error!("pkce create_state failed: {e}");
            return auth_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "authorization flow could not be started",
            );
        },
    };

    let location = state.oidc_client.authorization_url(
        &outbound_token,
        &PkceStateStore::s256_challenge(&verifier),
        "S256",
    );
    Redirect::to(&location).into_response()
}
#[allow(clippy::cognitive_complexity)] pub async fn auth_callback(
State(state): State<Arc<AuthPkceState>>,
Query(q): Query<AuthCallbackQuery>,
) -> Response {
if let Some(err) = q.error {
let desc = q.error_description.as_deref().unwrap_or("(no description provided)");
tracing::warn!(oidc_error = %err, description = %desc, "OIDC provider returned error");
let client_message = match err.as_str() {
"access_denied" => "Access was denied",
"login_required" => "Authentication is required",
"invalid_request" | "invalid_scope" => "Invalid authorization request",
"server_error" | "temporarily_unavailable" => "Authorization server error",
_ => "Authorization failed",
};
return auth_error(StatusCode::BAD_REQUEST, client_message);
}
let (Some(code), Some(state_token)) = (q.code, q.state) else {
return auth_error(StatusCode::BAD_REQUEST, "missing code or state parameter");
};
let pkce = match state.pkce_store.consume_state(&state_token).await {
Ok(s) => s,
Err(e) => {
tracing::debug!(error = %e, "pkce consume_state failed");
return auth_error(StatusCode::BAD_REQUEST, &e.to_string());
},
};
let tokens = match state
.oidc_client
.exchange_code(&code, &pkce.verifier, &state.http_client)
.await
{
Ok(t) => t,
Err(e) => {
tracing::error!("token exchange failed: {e}");
return auth_error(StatusCode::BAD_GATEWAY, "token exchange with OIDC provider failed");
},
};
if let Some(redirect_uri) = &state.post_login_redirect_uri {
let max_age = tokens.expires_in.unwrap_or(300);
let token_escaped = tokens.access_token.replace('\\', r"\\").replace('"', r#"\""#);
let cookie = format!(
r#"__Host-access_token="{token_escaped}"; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={max_age}"#,
);
let mut resp = Redirect::to(redirect_uri).into_response();
match cookie.parse() {
Ok(value) => {
resp.headers_mut().insert(header::SET_COOKIE, value);
},
Err(e) => {
tracing::error!("Failed to parse Set-Cookie header: {e}");
return auth_error(
StatusCode::INTERNAL_SERVER_ERROR,
"session cookie could not be set",
);
},
}
resp
} else {
Json(TokenJson {
access_token: tokens.access_token,
id_token: tokens.id_token,
expires_in: tokens.expires_in,
token_type: "Bearer",
})
.into_response()
}
}
/// Request body for single-token revocation.
#[derive(Deserialize)]
pub struct RevokeTokenRequest {
/// Encoded JWT to revoke. `revoke_token` reads its `jti`/`exp` claims without verifying
/// the signature.
pub token: String,
}
/// Response body for single-token revocation.
#[derive(Serialize)]
pub struct RevokeTokenResponse {
/// `true` whenever this body is returned (failures use the JSON error shape instead).
pub revoked: bool,
/// The token's `exp` as RFC 3339, or its raw number if it is not a representable
/// timestamp. Omitted when the token had no `exp`.
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<String>,
}
/// Shared state for the revocation routes (`revoke_token`, `revoke_all_tokens`).
pub struct RevocationRouteState {
/// Backend that records revoked `jti`s and per-user revocations.
pub revocation_manager: std::sync::Arc<crate::token_revocation::TokenRevocationManager>,
}
/// Revokes a single token, identified by its `jti` claim.
///
/// The token is decoded **without signature verification**, purely to read `jti` and `exp`.
/// NOTE(review): this route presumably sits behind an authorization layer — confirm.
///
/// The revocation TTL is the number of seconds until `exp`. It falls back to 86 400 (24 h) when
/// `exp` is missing or already in the past.
///
/// Responses:
/// - `400` if the token cannot be decoded or has a missing/empty `jti`.
/// - `500` if the revocation manager fails.
/// - Otherwise `{"revoked": true, "expires_at": …}`.
pub async fn revoke_token(
    State(state): State<std::sync::Arc<RevocationRouteState>>,
    Json(body): Json<RevokeTokenRequest>,
) -> Response {
    /// The only claims this handler needs from the token.
    #[derive(serde::Deserialize)]
    struct MinimalClaims {
        jti: Option<String>,
        exp: Option<u64>,
    }

    /// Revocation TTL (seconds) used when `exp` is absent or already passed.
    const DEFAULT_TTL_SECS: u64 = 86400;

    let decoded = jsonwebtoken::dangerous::insecure_decode::<MinimalClaims>(&body.token);
    let MinimalClaims { jti, exp } = match decoded {
        Ok(data) => data.claims,
        Err(e) => return auth_error(StatusCode::BAD_REQUEST, &format!("Invalid token: {e}")),
    };

    let Some(jti) = jti.filter(|id| !id.is_empty()) else {
        return auth_error(StatusCode::BAD_REQUEST, "Token has no jti claim");
    };

    let ttl_secs = match exp {
        Some(exp_secs) => {
            let now = chrono::Utc::now().timestamp().cast_unsigned();
            exp_secs.checked_sub(now).unwrap_or(DEFAULT_TTL_SECS)
        },
        None => DEFAULT_TTL_SECS,
    };

    if let Err(e) = state.revocation_manager.revoke(&jti, ttl_secs).await {
        tracing::error!(error = %e, "Failed to revoke token");
        return auth_error(StatusCode::INTERNAL_SERVER_ERROR, "Failed to revoke token");
    }

    let expires_at = exp.map(|exp_secs| {
        match chrono::DateTime::from_timestamp(exp_secs.cast_signed(), 0) {
            Some(dt) => dt.to_rfc3339(),
            None => exp_secs.to_string(),
        }
    });

    Json(RevokeTokenResponse {
        revoked: true,
        expires_at,
    })
    .into_response()
}
/// Request body for revoking every token of one subject.
#[derive(Deserialize)]
pub struct RevokeAllRequest {
/// Subject whose tokens are revoked; `revoke_all_tokens` rejects an empty value.
pub sub: String,
}
/// Response body for per-subject revocation.
#[derive(Serialize)]
pub struct RevokeAllResponse {
/// Count reported by `TokenRevocationManager::revoke_all_for_user`.
pub revoked_count: u64,
}
/// Revokes every token belonging to the subject `sub`.
///
/// Returns:
/// - `400` when `sub` is empty.
/// - `500` when the revocation manager fails; the failure is logged with the subject.
/// - Otherwise `{"revoked_count": <n>}`.
pub async fn revoke_all_tokens(
    State(state): State<std::sync::Arc<RevocationRouteState>>,
    Json(body): Json<RevokeAllRequest>,
) -> Response {
    if body.sub.is_empty() {
        return auth_error(StatusCode::BAD_REQUEST, "sub is required");
    }

    let outcome = state.revocation_manager.revoke_all_for_user(&body.sub).await;
    outcome.map_or_else(
        |e| {
            tracing::error!(error = %e, sub = %body.sub, "Failed to revoke tokens for user");
            auth_error(StatusCode::INTERNAL_SERVER_ERROR, "Failed to revoke tokens")
        },
        |revoked_count| Json(RevokeAllResponse { revoked_count }).into_response(),
    )
}
/// Configuration for `GET /auth/me`.
pub struct AuthMeState {
/// Names of extra claims copied into the `/auth/me` response when present on the user.
/// Names may be namespaced (e.g. `https://myapp.com/role`).
pub expose_claims: Vec<String>,
}
/// Returns the authenticated caller's identity as a JSON object.
///
/// The object always contains:
/// - `sub` and `user_id`, both set to the user id;
/// - `expires_at`, in RFC 3339 form.
///
/// Each claim named in `AuthMeState::expose_claims` is copied from the user's extra claims
/// when present and silently skipped otherwise. It is inserted after the fixed fields, so a
/// configured claim with the same name would replace one of them.
pub async fn auth_me(
    axum::extract::State(state): axum::extract::State<std::sync::Arc<AuthMeState>>,
    axum::Extension(auth_user): axum::Extension<crate::middleware::AuthUser>,
) -> axum::response::Response {
    use axum::{Json, response::IntoResponse as _};

    let user = &auth_user.0;
    let user_id = serde_json::Value::String(user.user_id.clone());
    let expires_at = serde_json::Value::String(user.expires_at.to_rfc3339());

    let mut body = serde_json::Map::new();
    body.insert("sub".to_owned(), user_id.clone());
    body.insert("user_id".to_owned(), user_id);
    body.insert("expires_at".to_owned(), expires_at);
    body.extend(state.expose_claims.iter().filter_map(|name| {
        user.extra_claims.get(name).map(|value| (name.clone(), value.clone()))
    }));

    Json(serde_json::Value::Object(body)).into_response()
}
// Unit tests for the auth route handlers. Each test drives a handler through an axum `Router`
// with `tower::ServiceExt::oneshot`. No real OIDC provider exists, so callbacks that reach
// the token exchange are expected to fail with 502.
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)]
use axum::{Extension, Router, body::Body, http::Request, routing::get};
use chrono::Utc;
use tower::ServiceExt as _;
use super::*;
use crate::{auth::PkceStateStore, middleware::AuthUser};
// PKCE store without state encryption (second argument `None`). The `600` is presumably a
// state TTL in seconds — confirm against `PkceStateStore::new`.
fn mock_pkce_store() -> Arc<PkceStateStore> {
Arc::new(PkceStateStore::new(600, None))
}
// Builds an authenticated user that expires one hour from now, with the given extra claims.
fn make_auth_user(
user_id: &str,
extra: std::collections::HashMap<String, serde_json::Value>,
) -> AuthUser {
AuthUser(fraiseql_core::security::AuthenticatedUser {
user_id: user_id.to_owned(),
scopes: vec![],
expires_at: Utc::now() + chrono::Duration::hours(1),
extra_claims: extra,
})
}
// Builds `/auth/me` state exposing exactly the listed claim names.
fn make_me_state(expose_claims: Vec<&str>) -> Arc<AuthMeState> {
Arc::new(AuthMeState {
expose_claims: expose_claims.into_iter().map(str::to_owned).collect(),
})
}
// `/auth/me` always reports sub, user_id and expires_at, even with no exposed claims.
#[tokio::test]
async fn test_auth_me_always_returns_sub_user_id_expires_at() {
let app = Router::new()
.route("/auth/me", get(auth_me))
.layer(Extension(make_auth_user("user-123", std::collections::HashMap::new())))
.with_state(make_me_state(vec![]));
let req = Request::builder().uri("/auth/me").body(Body::empty()).unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), axum::http::StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(json["sub"], "user-123");
assert_eq!(json["user_id"], "user-123");
assert!(json["expires_at"].is_string(), "expires_at must be present");
}
// Only claims listed in `expose_claims` are copied; others on the user stay hidden.
#[tokio::test]
async fn test_auth_me_expose_claims_filters_correctly() {
let mut extra = std::collections::HashMap::new();
extra.insert("email".to_owned(), serde_json::json!("alice@example.com"));
extra.insert("https://myapp.com/role".to_owned(), serde_json::json!("admin"));
let app = Router::new()
.route("/auth/me", get(auth_me))
.layer(Extension(make_auth_user("alice", extra)))
.with_state(make_me_state(vec!["email"]));
let req = Request::builder().uri("/auth/me").body(Body::empty()).unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(json["email"], "alice@example.com", "listed claim must appear");
assert!(json.get("https://myapp.com/role").is_none(), "unlisted claim must be absent");
}
// A configured claim the user does not have is left out entirely (no `null` entry).
#[tokio::test]
async fn test_auth_me_claim_absent_from_token_silently_omitted() {
let app = Router::new()
.route("/auth/me", get(auth_me))
.layer(Extension(make_auth_user("user-x", std::collections::HashMap::new())))
.with_state(make_me_state(vec!["tenant_id"]));
let req = Request::builder().uri("/auth/me").body(Body::empty()).unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(json.get("tenant_id").is_none(), "absent claim must not be null-padded");
assert_eq!(json["sub"], "user-x");
}
// URL-style namespaced claim names work as `expose_claims` keys.
#[tokio::test]
async fn test_auth_me_namespaced_claim_in_expose_claims() {
let mut extra = std::collections::HashMap::new();
extra.insert("https://myapp.com/role".to_owned(), serde_json::json!("editor"));
let app = Router::new()
.route("/auth/me", get(auth_me))
.layer(Extension(make_auth_user("user-y", extra)))
.with_state(make_me_state(vec!["https://myapp.com/role"]));
let req = Request::builder().uri("/auth/me").body(Body::empty()).unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(json["https://myapp.com/role"], "editor");
}
// OIDC client pointing at non-existent example.com endpoints. The first argument is the
// client id (asserted as `client_id=test-client` below). The rest presumably are the client
// secret, callback URL, and authorization and token endpoints — confirm against
// `OidcServerClient::new`.
fn mock_oidc_client() -> Arc<OidcServerClient> {
Arc::new(OidcServerClient::new(
"test-client",
"test-secret",
"https://api.example.com/auth/callback",
"https://provider.example.com/authorize",
"https://provider.example.com/token",
))
}
// Router with /auth/start and /auth/callback, unencrypted PKCE state, and JSON token output
// (no post-login redirect).
fn auth_router() -> Router {
let auth_state = Arc::new(AuthPkceState {
pkce_store: mock_pkce_store(),
oidc_client: mock_oidc_client(),
http_client: Arc::new(reqwest::Client::new()),
post_login_redirect_uri: None,
});
Router::new()
.route("/auth/start", get(auth_start))
.route("/auth/callback", get(auth_callback))
.with_state(auth_state)
}
// /auth/start redirects to the provider with the PKCE/authorization-code query parameters.
#[tokio::test]
async fn test_auth_start_redirects_with_pkce_params() {
let app = auth_router();
let req = Request::builder()
.uri("/auth/start?redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert!(resp.status().is_redirection(), "expected redirect, got {}", resp.status());
let location = resp
.headers()
.get(header::LOCATION)
.and_then(|v| v.to_str().ok())
.expect("Location header must be present");
assert!(location.contains("response_type=code"), "missing response_type");
assert!(location.contains("code_challenge="), "missing code_challenge");
assert!(location.contains("code_challenge_method=S256"), "missing challenge method");
assert!(location.contains("state="), "missing state param");
assert!(location.contains("client_id=test-client"), "missing client_id");
}
// Omitting redirect_uri is rejected by the `Query` extractor with some 4xx status.
#[tokio::test]
async fn test_auth_start_missing_redirect_uri_returns_400() {
let app = auth_router();
let req = Request::builder().uri("/auth/start").body(Body::empty()).unwrap();
let resp = app.oneshot(req).await.unwrap();
assert!(
resp.status().is_client_error(),
"missing redirect_uri must be a client error, got {}",
resp.status()
);
}
// A state token that was never issued yields 400 with a JSON `error` string.
#[tokio::test]
async fn test_auth_callback_unknown_state_returns_400() {
let app = auth_router();
let req = Request::builder()
.uri("/auth/callback?code=abc&state=completely-unknown-state")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(json["error"].is_string(), "error field must be a string: {json}");
}
// A callback carrying `state` but no `code` yields 400.
#[tokio::test]
async fn test_auth_callback_missing_code_returns_400() {
let app = auth_router();
let req = Request::builder()
.uri("/auth/callback?state=some-state-no-code")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
}
// A redirect_uri over the 2048-byte limit yields 400 mentioning the length limit.
#[tokio::test]
async fn test_auth_start_oversized_redirect_uri_returns_400() {
let app = auth_router();
let long_uri = "https://example.com/".to_string() + &"a".repeat(2100);
let encoded = urlencoding::encode(&long_uri);
let req = Request::builder()
.uri(format!("/auth/start?redirect_uri={encoded}"))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(
json["error"].as_str().unwrap_or("").contains("maximum length"),
"error must mention length: {json}"
);
}
// A known provider error maps to its fixed message; error_description is not echoed back.
#[tokio::test]
async fn test_auth_callback_oidc_error_returns_mapped_message() {
let app = auth_router();
let req = Request::builder()
.uri("/auth/callback?error=access_denied&error_description=internal+tenant+info")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let error_msg = json["error"].as_str().unwrap_or("");
assert!(
!error_msg.contains("internal tenant info"),
"provider description must not be reflected to client: {error_msg}"
);
assert_eq!(error_msg, "Access was denied");
}
// An unrecognised provider error code falls back to the generic message.
#[tokio::test]
async fn test_auth_callback_unknown_oidc_error_returns_generic_message() {
let app = auth_router();
let req = Request::builder()
.uri("/auth/callback?error=unknown_vendor_error&error_description=secret+details")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(json["error"].as_str().unwrap_or(""), "Authorization failed");
}
// A provider error without error_description still maps to the fixed message.
#[tokio::test]
async fn test_auth_callback_oidc_error_no_description_uses_fallback() {
let app = auth_router();
let req = Request::builder()
.uri("/auth/callback?error=access_denied")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert_eq!(json["error"].as_str().unwrap_or(""), "Access was denied");
}
// End-to-end with ChaCha20-Poly1305 state encryption: the state issued by /auth/start must
// be accepted by /auth/callback (not 400). Because no real provider exists, the token
// exchange step then fails with 502.
#[tokio::test]
async fn test_auth_start_to_callback_state_roundtrip_with_encryption() {
use crate::auth::{EncryptionAlgorithm, StateEncryptionService};
let enc = Arc::new(StateEncryptionService::from_raw_key(
&[0u8; 32],
EncryptionAlgorithm::Chacha20Poly1305,
));
let pkce_store = Arc::new(PkceStateStore::new(600, Some(enc)));
let auth_state = Arc::new(AuthPkceState {
pkce_store,
oidc_client: mock_oidc_client(),
http_client: Arc::new(reqwest::Client::new()),
post_login_redirect_uri: None,
});
let app = Router::new()
.route("/auth/start", get(auth_start))
.route("/auth/callback", get(auth_callback))
.with_state(auth_state);
let req = Request::builder()
.uri("/auth/start?redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert!(
resp.status().is_redirection(),
"expected redirect from /auth/start, got {}",
resp.status(),
);
let location = resp
.headers()
.get(header::LOCATION)
.and_then(|v| v.to_str().ok())
.expect("Location header must be set")
.to_string();
// Parse the Location URL so the state token is percent-decoded before it is replayed.
let parsed_location =
reqwest::Url::parse(&location).expect("Location header must be a valid URL");
let state_token = parsed_location
.query_pairs()
.find(|(k, _)| k == "state")
.map(|(_, v)| v.into_owned())
.expect("state= must appear in the redirect Location URL");
assert!(!state_token.is_empty(), "extracted state token must not be empty");
let callback_uri = format!("/auth/callback?code=test_code&state={state_token}");
let req2 = Request::builder().uri(&callback_uri).body(Body::empty()).unwrap();
let resp2 = app.clone().oneshot(req2).await.unwrap();
assert_ne!(
resp2.status(),
StatusCode::BAD_REQUEST,
"state from /auth/start must be accepted by /auth/callback; \
400 means the PKCE state was not found or decryption failed",
);
assert_eq!(
resp2.status(),
StatusCode::BAD_GATEWAY,
"token exchange should fail 502 (no real OIDC provider); got {}",
resp2.status(),
);
}
}