Skip to main content

fraiseql_server/routes/
auth.rs

1//! PKCE `OAuth2` route handlers: `/auth/start` and `/auth/callback`.
2//!
3//! These routes implement the `OAuth2` Authorization Code flow with PKCE
4//! (RFC 7636) for server-side relying-party use.  FraiseQL acts as the
5//! OAuth client; the OIDC provider performs the actual authentication.
6//!
7//! # Flow
8//!
9//! ```text
10//! GET /auth/start?redirect_uri=https://app.example.com/after-login
//!   → 303 See Other → OIDC provider /authorize?...&code_challenge=...&state=...
12//!
13//! GET /auth/callback?code=<code>&state=<state>
14//!   → [verify state, exchange code+verifier for tokens]
15//!   → 200 JSON { access_token, id_token, expires_in, token_type }
//!   OR 303 See Other + Set-Cookie (when post_login_redirect_uri is configured)
17//! ```
18//!
19//! Routes are only mounted when `[security.pkce] enabled = true` AND `[auth]`
20//! is configured in the compiled schema.  See `server.rs` for the wiring.
21
22use std::sync::Arc;
23
24use axum::{
25    Json,
26    extract::{Query, State},
27    http::{StatusCode, header},
28    response::{IntoResponse, Redirect, Response},
29};
30use serde::{Deserialize, Serialize};
31
32use crate::auth::{OidcServerClient, PkceStateStore};
33
34/// Shared state injected into both PKCE route handlers.
35pub struct AuthPkceState {
36    /// In-memory PKCE state store (encrypted when `state_encryption` is on).
37    pub pkce_store:              Arc<PkceStateStore>,
38    /// Server-side OIDC client for building authorize URLs and exchanging codes.
39    pub oidc_client:             Arc<OidcServerClient>,
40    /// Shared HTTP client for token-endpoint calls.
41    pub http_client:             Arc<reqwest::Client>,
42    /// When set, the callback redirects here with the token in a
43    /// `Secure; HttpOnly; SameSite=Strict` cookie instead of returning JSON.
44    pub post_login_redirect_uri: Option<String>,
45}
46
47// ---------------------------------------------------------------------------
48// Query parameter structs
49// ---------------------------------------------------------------------------
50
51/// Query parameters accepted by `GET /auth/start`.
52#[derive(Deserialize)]
53pub struct AuthStartQuery {
54    /// The URI within the **client application** to redirect to after a
55    /// successful login.  This is stored in the PKCE state store and
56    /// returned to the caller at callback time via the `redirect_uri` in
57    /// the consumed state.
58    redirect_uri: String,
59}
60
61/// Query parameters sent by the OIDC provider to `GET /auth/callback`.
62#[derive(Deserialize)]
63pub struct AuthCallbackQuery {
64    /// Authorization code to exchange for tokens.
65    code:              Option<String>,
66    /// State token for CSRF and PKCE state lookup.
67    state:             Option<String>,
68    /// OIDC provider error code (e.g. `"access_denied"`).
69    error:             Option<String>,
70    /// Human-readable error description from the provider.
71    error_description: Option<String>,
72}
73
74// ---------------------------------------------------------------------------
75// Response body (JSON path)
76// ---------------------------------------------------------------------------
77
78#[derive(Serialize)]
79struct TokenJson {
80    access_token: String,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    id_token:     Option<String>,
83    #[serde(skip_serializing_if = "Option::is_none")]
84    expires_in:   Option<u64>,
85    token_type:   &'static str,
86}
87
88// ---------------------------------------------------------------------------
89// Helpers
90// ---------------------------------------------------------------------------
91
92fn auth_error(status: StatusCode, message: &str) -> Response {
93    (status, Json(serde_json::json!({ "error": message }))).into_response()
94}
95
96// ---------------------------------------------------------------------------
97// GET /auth/start
98// ---------------------------------------------------------------------------
99
100/// Initiate a PKCE authorization code flow.
101///
102/// Generates a `code_verifier` and `code_challenge`, stores state in the
103/// [`PkceStateStore`], then redirects the user-agent to the OIDC provider.
104///
105/// # Query parameters
106///
107/// - `redirect_uri` — **required**: the client application's callback URI.
108///
109/// # Responses
110///
111/// - `302` — redirect to the OIDC provider's `/authorize` endpoint.
112/// - `400` — `redirect_uri` is missing.
113/// - `500` — internal error generating state (essentially impossible).
114pub async fn auth_start(
115    State(state): State<Arc<AuthPkceState>>,
116    Query(q): Query<AuthStartQuery>,
117) -> Response {
118    if q.redirect_uri.is_empty() {
119        return auth_error(StatusCode::BAD_REQUEST, "redirect_uri is required");
120    }
121    // Enforce a length cap to prevent memory amplification via the PKCE state store
122    // (in-memory or Redis) and to limit encrypted state blob size.
123    if q.redirect_uri.len() > 2048 {
124        return auth_error(StatusCode::BAD_REQUEST, "redirect_uri exceeds maximum length");
125    }
126
127    let (outbound_token, verifier) = match state.pkce_store.create_state(&q.redirect_uri).await {
128        Ok(v) => v,
129        Err(e) => {
130            tracing::error!("pkce create_state failed: {e}");
131            return auth_error(
132                StatusCode::INTERNAL_SERVER_ERROR,
133                "authorization flow could not be started",
134            );
135        },
136    };
137
138    let challenge = PkceStateStore::s256_challenge(&verifier);
139    let location = state.oidc_client.authorization_url(&outbound_token, &challenge, "S256");
140
141    Redirect::to(&location).into_response()
142}
143
144// ---------------------------------------------------------------------------
145// GET /auth/callback
146// ---------------------------------------------------------------------------
147
148/// Complete the PKCE authorization code flow.
149///
150/// Validates the `state` parameter, recovers the `code_verifier`, then
151/// exchanges the authorization `code` at the OIDC token endpoint.
152///
153/// # Query parameters
154///
155/// - `code`  — authorization code from the provider.
156/// - `state` — state token (may be encrypted).
157///
158/// The provider may also call this endpoint with `?error=…` when the user
159/// denies access; those are surfaced as `400` responses.
160///
161/// # Responses
162///
163/// - `200` JSON `{ access_token, id_token?, expires_in?, token_type }`. Or `302` with `Set-Cookie`
164///   when `post_login_redirect_uri` is configured.
165/// - `400` — invalid/expired state, missing parameters, or provider error.
166/// - `502` — token exchange with the OIDC provider failed.
167#[allow(clippy::cognitive_complexity)] // Reason: OAuth callback handler with state validation, token exchange, and redirect logic
168pub async fn auth_callback(
169    State(state): State<Arc<AuthPkceState>>,
170    Query(q): Query<AuthCallbackQuery>,
171) -> Response {
172    // ── Surface OIDC provider errors immediately ──────────────────────────
173    if let Some(err) = q.error {
174        let desc = q.error_description.as_deref().unwrap_or("(no description provided)");
175        // Log the full provider response for debugging, but return only a
176        // fixed allowlisted message to the client to avoid leaking internal
177        // provider details (tenant info, stack traces) or enabling injection.
178        tracing::warn!(oidc_error = %err, description = %desc, "OIDC provider returned error");
179        let client_message = match err.as_str() {
180            "access_denied" => "Access was denied",
181            "login_required" => "Authentication is required",
182            "invalid_request" | "invalid_scope" => "Invalid authorization request",
183            "server_error" | "temporarily_unavailable" => "Authorization server error",
184            _ => "Authorization failed",
185        };
186        return auth_error(StatusCode::BAD_REQUEST, client_message);
187    }
188
189    // ── Validate required parameters ──────────────────────────────────────
190    let (Some(code), Some(state_token)) = (q.code, q.state) else {
191        return auth_error(StatusCode::BAD_REQUEST, "missing code or state parameter");
192    };
193
194    // ── Consume PKCE state (atomic remove) ───────────────────────────────
195    let pkce = match state.pkce_store.consume_state(&state_token).await {
196        Ok(s) => s,
197        Err(e) => {
198            // Both StateNotFound and StateExpired are client errors.
199            // Log at debug to avoid spamming warnings from probing attacks.
200            tracing::debug!(error = %e, "pkce consume_state failed");
201            return auth_error(StatusCode::BAD_REQUEST, &e.to_string());
202        },
203    };
204
205    // ── Exchange code + verifier at the OIDC provider ────────────────────
206    let tokens = match state
207        .oidc_client
208        .exchange_code(&code, &pkce.verifier, &state.http_client)
209        .await
210    {
211        Ok(t) => t,
212        Err(e) => {
213            tracing::error!("token exchange failed: {e}");
214            return auth_error(StatusCode::BAD_GATEWAY, "token exchange with OIDC provider failed");
215        },
216    };
217
218    // ── Return tokens ─────────────────────────────────────────────────────
219    if let Some(redirect_uri) = &state.post_login_redirect_uri {
220        // Browser flow: redirect to frontend, set token in HttpOnly cookie.
221        // The redirect target is server-configured (not from pkce.redirect_uri —
222        // IMPORTANT: pkce.redirect_uri MUST NOT be used to construct an HTTP
223        // redirect without allowlist validation; its value is caller-supplied
224        // and could be attacker-controlled).
225        //
226        // Cookie notes:
227        // - `__Host-` prefix mandates Secure, Path=/, no Domain, blocking subdomain override.
228        // - Token value is double-quoted (RFC 6265 quoted-string) to safely embed any printable
229        //   ASCII that OAuth servers may include.
230        // - Max-Age uses 300s when expires_in is absent — a conservative default that prevents the
231        //   cookie outliving a short-lived token by a large margin.
232        let max_age = tokens.expires_in.unwrap_or(300);
233        // Escape '"' and '\' inside the token value per RFC 6265 quoted-string rules.
234        let token_escaped = tokens.access_token.replace('\\', r"\\").replace('"', r#"\""#);
235        let cookie = format!(
236            r#"__Host-access_token="{token_escaped}"; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={max_age}"#,
237        );
238        let mut resp = Redirect::to(redirect_uri).into_response();
239        match cookie.parse() {
240            Ok(value) => {
241                resp.headers_mut().insert(header::SET_COOKIE, value);
242            },
243            Err(e) => {
244                tracing::error!("Failed to parse Set-Cookie header: {e}");
245                return auth_error(
246                    StatusCode::INTERNAL_SERVER_ERROR,
247                    "session cookie could not be set",
248                );
249            },
250        }
251        resp
252    } else {
253        // API / native app flow: return tokens as JSON.
254        Json(TokenJson {
255            access_token: tokens.access_token,
256            id_token:     tokens.id_token,
257            expires_in:   tokens.expires_in,
258            token_type:   "Bearer",
259        })
260        .into_response()
261    }
262}
263
264// ---------------------------------------------------------------------------
265// POST /auth/revoke
266// ---------------------------------------------------------------------------
267
268/// Request body for token revocation.
269#[derive(Deserialize)]
270pub struct RevokeTokenRequest {
271    /// The JWT to revoke (we extract `jti` and `exp` from it).
272    pub token: String,
273}
274
275/// Response body for token revocation.
276#[derive(Serialize)]
277pub struct RevokeTokenResponse {
278    /// Whether the token was successfully revoked.
279    pub revoked:    bool,
280    /// ISO-8601 timestamp at which the revocation record will expire, if known.
281    #[serde(skip_serializing_if = "Option::is_none")]
282    pub expires_at: Option<String>,
283}
284
285/// Shared state for revocation routes.
286pub struct RevocationRouteState {
287    /// Token revocation manager used to record and check revoked JTIs.
288    pub revocation_manager: std::sync::Arc<crate::token_revocation::TokenRevocationManager>,
289}
290
291/// Revoke a single JWT by its `jti` claim.
292///
293/// The token is decoded (without verification — we only need the claims) to
294/// extract `jti` and `exp`.  The revocation entry TTL is set to the remaining
295/// token lifetime so the store auto-cleans.
296///
297/// # Responses
298///
299/// - `200` — token revoked successfully.
300/// - `400` — token is missing or has no `jti` claim.
301pub async fn revoke_token(
302    State(state): State<std::sync::Arc<RevocationRouteState>>,
303    Json(body): Json<RevokeTokenRequest>,
304) -> Response {
305    #[derive(serde::Deserialize)]
306    struct MinimalClaims {
307        jti: Option<String>,
308        exp: Option<u64>,
309    }
310
311    // Decode without signature verification — we only need the claims for revocation.
312    let claims = match jsonwebtoken::dangerous::insecure_decode::<MinimalClaims>(&body.token) {
313        Ok(data) => data.claims,
314        Err(e) => {
315            return auth_error(StatusCode::BAD_REQUEST, &format!("Invalid token: {e}"));
316        },
317    };
318
319    let jti = match claims.jti {
320        Some(j) if !j.is_empty() => j,
321        _ => {
322            return auth_error(StatusCode::BAD_REQUEST, "Token has no jti claim");
323        },
324    };
325
326    // TTL = remaining token lifetime, or 24h if no exp.
327    let ttl_secs = claims
328        .exp
329        .and_then(|exp| {
330            let now = chrono::Utc::now().timestamp().cast_unsigned();
331            exp.checked_sub(now)
332        })
333        .unwrap_or(86400);
334
335    if let Err(e) = state.revocation_manager.revoke(&jti, ttl_secs).await {
336        tracing::error!(error = %e, "Failed to revoke token");
337        return auth_error(StatusCode::INTERNAL_SERVER_ERROR, "Failed to revoke token");
338    }
339
340    let expires_at = claims.exp.map(|exp| {
341        chrono::DateTime::from_timestamp(exp.cast_signed(), 0)
342            .map_or_else(|| exp.to_string(), |dt| dt.to_rfc3339())
343    });
344
345    Json(RevokeTokenResponse {
346        revoked: true,
347        expires_at,
348    })
349    .into_response()
350}
351
352// ---------------------------------------------------------------------------
353// POST /auth/revoke-all
354// ---------------------------------------------------------------------------
355
356/// Request body for revoking all tokens for a user.
357#[derive(Deserialize)]
358pub struct RevokeAllRequest {
359    /// User subject (from JWT `sub` claim).
360    pub sub: String,
361}
362
363/// Response body for bulk revocation.
364#[derive(Serialize)]
365pub struct RevokeAllResponse {
366    /// Number of token revocation records that were created.
367    pub revoked_count: u64,
368}
369
370/// Revoke all tokens for a user.
371///
372/// # Responses
373///
374/// - `200` — tokens revoked.
375/// - `400` — `sub` is missing or empty.
376pub async fn revoke_all_tokens(
377    State(state): State<std::sync::Arc<RevocationRouteState>>,
378    Json(body): Json<RevokeAllRequest>,
379) -> Response {
380    if body.sub.is_empty() {
381        return auth_error(StatusCode::BAD_REQUEST, "sub is required");
382    }
383
384    match state.revocation_manager.revoke_all_for_user(&body.sub).await {
385        Ok(count) => Json(RevokeAllResponse {
386            revoked_count: count,
387        })
388        .into_response(),
389        Err(e) => {
390            tracing::error!(error = %e, sub = %body.sub, "Failed to revoke tokens for user");
391            auth_error(StatusCode::INTERNAL_SERVER_ERROR, "Failed to revoke tokens")
392        },
393    }
394}
395
396// ---------------------------------------------------------------------------
397// GET /auth/me
398// ---------------------------------------------------------------------------
399
/// Configuration for the [`auth_me`] handler, built from the `[auth.me]`
/// config section.
pub struct AuthMeState {
    /// Allowlist of raw JWT claim names to echo back, in addition to the
    /// fixed `sub`, `user_id`, and `expires_at` fields.
    pub expose_claims: Vec<String>,
}
406
407/// Return the current session's identity as JSON.
408///
409/// Reads the [`crate::middleware::AuthUser`] request extension populated by
410/// `oidc_auth_middleware` and reflects a configurable subset of the validated
411/// JWT claims back to the caller.
412///
413/// The response always contains:
414/// - `sub` — the standard JWT subject (user ID).
415/// - `user_id` — hardcoded alias for `sub`; more ergonomic for frontend code.
416/// - `expires_at` — ISO-8601 timestamp when the session expires.
417///
418/// Additional fields are included only when (a) the claim name appears in the
419/// `expose_claims` allowlist **and** (b) the claim is present in the token.
420/// Claims in the allowlist but absent from the token are silently omitted —
421/// the response is never padded with `null` values.
422///
423/// The `user_id` alias for `sub` is always present and does **not** need to
424/// be listed in `expose_claims`.  Listing `"user_id"` there would silently
425/// return nothing because the JWT only carries `sub`, not `user_id`.
426///
427/// # Responses
428///
429/// - `200` JSON `{ sub, user_id, expires_at, ...expose_claims }`
430/// - `401` when no valid session is present (enforced by `oidc_auth_middleware` before this handler
431///   is called).
432pub async fn auth_me(
433    axum::extract::State(state): axum::extract::State<std::sync::Arc<AuthMeState>>,
434    axum::Extension(auth_user): axum::Extension<crate::middleware::AuthUser>,
435) -> axum::response::Response {
436    use axum::{Json, response::IntoResponse as _};
437
438    let user = &auth_user.0;
439
440    let mut map = serde_json::Map::new();
441    map.insert("sub".to_owned(), serde_json::Value::String(user.user_id.clone()));
442    map.insert("user_id".to_owned(), serde_json::Value::String(user.user_id.clone()));
443    map.insert("expires_at".to_owned(), serde_json::Value::String(user.expires_at.to_rfc3339()));
444
445    for claim_name in &state.expose_claims {
446        if let Some(value) = user.extra_claims.get(claim_name) {
447            map.insert(claim_name.clone(), value.clone());
448        }
449    }
450
451    Json(serde_json::Value::Object(map)).into_response()
452}
453
454// ---------------------------------------------------------------------------
455// Unit tests
456// ---------------------------------------------------------------------------
457
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use axum::{Extension, Router, body::Body, http::Request, routing::get};
    use chrono::Utc;
    use tower::ServiceExt as _;

    use super::*;
    use crate::{auth::PkceStateStore, middleware::AuthUser};

    fn mock_pkce_store() -> Arc<PkceStateStore> {
        Arc::new(PkceStateStore::new(600, None))
    }

    // -------------------------------------------------------------------------
    // auth_me tests
    // -------------------------------------------------------------------------

    fn make_auth_user(
        user_id: &str,
        extra: std::collections::HashMap<String, serde_json::Value>,
    ) -> AuthUser {
        AuthUser(fraiseql_core::security::AuthenticatedUser {
            user_id:      user_id.to_owned(),
            scopes:       vec![],
            expires_at:   Utc::now() + chrono::Duration::hours(1),
            extra_claims: extra,
        })
    }

    fn make_me_state(expose_claims: Vec<&str>) -> Arc<AuthMeState> {
        Arc::new(AuthMeState {
            expose_claims: expose_claims.into_iter().map(str::to_owned).collect(),
        })
    }

    #[tokio::test]
    async fn test_auth_me_always_returns_sub_user_id_expires_at() {
        let app = Router::new()
            .route("/auth/me", get(auth_me))
            .layer(Extension(make_auth_user("user-123", std::collections::HashMap::new())))
            .with_state(make_me_state(vec![]));

        let req = Request::builder().uri("/auth/me").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), axum::http::StatusCode::OK);

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();

        assert_eq!(json["sub"], "user-123");
        assert_eq!(json["user_id"], "user-123");
        assert!(json["expires_at"].is_string(), "expires_at must be present");
    }

    #[tokio::test]
    async fn test_auth_me_expose_claims_filters_correctly() {
        let mut extra = std::collections::HashMap::new();
        extra.insert("email".to_owned(), serde_json::json!("alice@example.com"));
        extra.insert("https://myapp.com/role".to_owned(), serde_json::json!("admin"));

        let app = Router::new()
            .route("/auth/me", get(auth_me))
            .layer(Extension(make_auth_user("alice", extra)))
            .with_state(make_me_state(vec!["email"]));

        let req = Request::builder().uri("/auth/me").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();

        assert_eq!(json["email"], "alice@example.com", "listed claim must appear");
        assert!(json.get("https://myapp.com/role").is_none(), "unlisted claim must be absent");
    }

    #[tokio::test]
    async fn test_auth_me_claim_absent_from_token_silently_omitted() {
        // expose_claims lists "tenant_id" but the token doesn't have it.
        let app = Router::new()
            .route("/auth/me", get(auth_me))
            .layer(Extension(make_auth_user("user-x", std::collections::HashMap::new())))
            .with_state(make_me_state(vec!["tenant_id"]));

        let req = Request::builder().uri("/auth/me").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();

        assert!(json.get("tenant_id").is_none(), "absent claim must not be null-padded");
        // Fixed fields must still be present.
        assert_eq!(json["sub"], "user-x");
    }

    #[tokio::test]
    async fn test_auth_me_namespaced_claim_in_expose_claims() {
        let mut extra = std::collections::HashMap::new();
        extra.insert("https://myapp.com/role".to_owned(), serde_json::json!("editor"));

        let app = Router::new()
            .route("/auth/me", get(auth_me))
            .layer(Extension(make_auth_user("user-y", extra)))
            .with_state(make_me_state(vec!["https://myapp.com/role"]));

        let req = Request::builder().uri("/auth/me").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();

        assert_eq!(json["https://myapp.com/role"], "editor");
    }

    /// OIDC client pointing at placeholder `example.com` endpoints; nothing
    /// real answers there, so any token exchange is expected to fail.
    fn mock_oidc_client() -> Arc<OidcServerClient> {
        Arc::new(OidcServerClient::new(
            "test-client",
            "test-secret",
            "https://api.example.com/auth/callback",
            "https://provider.example.com/authorize",
            "https://provider.example.com/token",
        ))
    }

    fn auth_router() -> Router {
        let auth_state = Arc::new(AuthPkceState {
            pkce_store:              mock_pkce_store(),
            oidc_client:             mock_oidc_client(),
            http_client:             Arc::new(reqwest::Client::new()),
            post_login_redirect_uri: None,
        });
        Router::new()
            .route("/auth/start", get(auth_start))
            .route("/auth/callback", get(auth_callback))
            .with_state(auth_state)
    }

    #[tokio::test]
    async fn test_auth_start_redirects_with_pkce_params() {
        let app = auth_router();
        let req = Request::builder()
            .uri("/auth/start?redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();

        // axum's Redirect::to() returns 303 See Other; allow any 3xx redirect.
        assert!(resp.status().is_redirection(), "expected redirect, got {}", resp.status());
        let location = resp
            .headers()
            .get(header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .expect("Location header must be present");

        assert!(location.contains("response_type=code"), "missing response_type");
        assert!(location.contains("code_challenge="), "missing code_challenge");
        assert!(location.contains("code_challenge_method=S256"), "missing challenge method");
        assert!(location.contains("state="), "missing state param");
        assert!(location.contains("client_id=test-client"), "missing client_id");
    }

    #[tokio::test]
    async fn test_auth_start_missing_redirect_uri_returns_400() {
        let app = auth_router();
        let req = Request::builder().uri("/auth/start").body(Body::empty()).unwrap();
        let resp = app.oneshot(req).await.unwrap();
        // A missing required query param is rejected by axum's `Query` extractor
        // before the handler runs; our own guard only covers an empty value.
        // Any client error is acceptable; what matters is it's not 200 or 3xx.
        assert!(
            resp.status().is_client_error(),
            "missing redirect_uri must be a client error, got {}",
            resp.status()
        );
    }

    #[tokio::test]
    async fn test_auth_callback_unknown_state_returns_400() {
        let app = auth_router();
        let req = Request::builder()
            .uri("/auth/callback?code=abc&state=completely-unknown-state")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        // Client receives a generic error string, not an internal panic.
        assert!(json["error"].is_string(), "error field must be a string: {json}");
    }

    #[tokio::test]
    async fn test_auth_callback_missing_code_returns_400() {
        let app = auth_router();
        let req = Request::builder()
            .uri("/auth/callback?state=some-state-no-code")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_auth_start_oversized_redirect_uri_returns_400() {
        let app = auth_router();
        let long_uri = "https://example.com/".to_string() + &"a".repeat(2100);
        let encoded = urlencoding::encode(&long_uri);
        let req = Request::builder()
            .uri(format!("/auth/start?redirect_uri={encoded}"))
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert!(
            json["error"].as_str().unwrap_or("").contains("maximum length"),
            "error must mention length: {json}"
        );
    }

    #[tokio::test]
    async fn test_auth_callback_oidc_error_returns_mapped_message() {
        let app = auth_router();
        // access_denied should map to a fixed message, not reflect provider strings
        let req = Request::builder()
            .uri("/auth/callback?error=access_denied&error_description=internal+tenant+info")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        let error_msg = json["error"].as_str().unwrap_or("");
        // Must not contain the raw provider description
        assert!(
            !error_msg.contains("internal tenant info"),
            "provider description must not be reflected to client: {error_msg}"
        );
        assert_eq!(error_msg, "Access was denied");
    }

    #[tokio::test]
    async fn test_auth_callback_unknown_oidc_error_returns_generic_message() {
        let app = auth_router();
        let req = Request::builder()
            .uri("/auth/callback?error=unknown_vendor_error&error_description=secret+details")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["error"].as_str().unwrap_or(""), "Authorization failed");
    }

    #[tokio::test]
    async fn test_auth_callback_oidc_error_no_description_uses_fallback() {
        let app = auth_router();
        let req = Request::builder()
            .uri("/auth/callback?error=access_denied")
            .body(Body::empty())
            .unwrap();
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        // The tracing log (not the HTTP response) includes the desc; the HTTP
        // response is the sanitised allowlist message. We verify the handler does
        // not panic and returns the mapped client message.
        let body = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["error"].as_str().unwrap_or(""), "Access was denied");
    }

    /// Full HTTP-level PKCE round-trip: `/auth/start` → extract state → `/auth/callback`.
    ///
    /// Verifies that the state token embedded in the `/auth/start` redirect can be
    /// submitted to `/auth/callback`, proving the PKCE store correctly survives the
    /// round-trip through the HTTP layer (including encryption when enabled).
    ///
    /// The callback will fail at token exchange (no real OIDC provider) and return 502,
    /// but NOT 400 — a 400 would indicate the state was not found in the store.
    #[tokio::test]
    async fn test_auth_start_to_callback_state_roundtrip_with_encryption() {
        use crate::auth::{EncryptionAlgorithm, StateEncryptionService};

        let enc = Arc::new(StateEncryptionService::from_raw_key(
            &[0u8; 32],
            EncryptionAlgorithm::Chacha20Poly1305,
        ));
        let pkce_store = Arc::new(PkceStateStore::new(600, Some(enc)));

        // The token exchange in step 2 makes a real outbound request to a
        // placeholder host. `reqwest::Client::new()` sets no request timeout,
        // so on a network that silently drops packets the test could stall
        // indefinitely. Bound it; a timeout still surfaces as a failed exchange
        // (502), which is what this test expects.
        let http_client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap();

        let auth_state = Arc::new(AuthPkceState {
            pkce_store,
            oidc_client: mock_oidc_client(),
            http_client: Arc::new(http_client),
            post_login_redirect_uri: None,
        });

        let app = Router::new()
            .route("/auth/start", get(auth_start))
            .route("/auth/callback", get(auth_callback))
            .with_state(auth_state);

        // Step 1 — /auth/start: receive redirect containing the encrypted state token.
        let req = Request::builder()
            .uri("/auth/start?redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb")
            .body(Body::empty())
            .unwrap();
        let resp = app.clone().oneshot(req).await.unwrap();

        assert!(
            resp.status().is_redirection(),
            "expected redirect from /auth/start, got {}",
            resp.status(),
        );

        let location = resp
            .headers()
            .get(header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .expect("Location header must be set")
            .to_string();

        // Extract the state= token from the redirect URL using proper URL parsing to
        // avoid false matches when "state=" appears elsewhere in the URL (e.g. in path
        // or other parameters).
        let parsed_location =
            reqwest::Url::parse(&location).expect("Location header must be a valid URL");
        let state_token = parsed_location
            .query_pairs()
            .find(|(k, _)| k == "state")
            .map(|(_, v)| v.into_owned())
            .expect("state= must appear in the redirect Location URL");

        assert!(!state_token.is_empty(), "extracted state token must not be empty");

        // Step 2 — /auth/callback: submit the real state token from step 1.
        // Expected result: 502 Bad Gateway (token exchange fails — no real OIDC provider).
        // A 400 would mean the PKCE state was not found, which would be a regression.
        let callback_uri = format!("/auth/callback?code=test_code&state={state_token}");
        let req2 = Request::builder().uri(&callback_uri).body(Body::empty()).unwrap();
        let resp2 = app.clone().oneshot(req2).await.unwrap();

        assert_ne!(
            resp2.status(),
            StatusCode::BAD_REQUEST,
            "state from /auth/start must be accepted by /auth/callback; \
             400 means the PKCE state was not found or decryption failed",
        );
        assert_eq!(
            resp2.status(),
            StatusCode::BAD_GATEWAY,
            "token exchange should fail 502 (no real OIDC provider); got {}",
            resp2.status(),
        );
    }
}