Skip to main content

fraiseql_server/routes/
auth.rs

1//! PKCE `OAuth2` route handlers: `/auth/start` and `/auth/callback`.
2//!
3//! These routes implement the `OAuth2` Authorization Code flow with PKCE
4//! (RFC 7636) for server-side relying-party use.  FraiseQL acts as the
5//! OAuth client; the OIDC provider performs the actual authentication.
6//!
7//! # Flow
8//!
9//! ```text
10//! GET /auth/start?redirect_uri=https://app.example.com/after-login
//!   → 303 See Other → OIDC provider /authorize?...&code_challenge=...&state=...
12//!
13//! GET /auth/callback?code=<code>&state=<state>
14//!   → [verify state, exchange code+verifier for tokens]
15//!   → 200 JSON { access_token, id_token, expires_in, token_type }
//!   OR 303 See Other + Set-Cookie (when post_login_redirect_uri is configured)
17//! ```
18//!
19//! Routes are only mounted when `[security.pkce] enabled = true` AND `[auth]`
20//! is configured in the compiled schema.  See `server.rs` for the wiring.
21
22use std::sync::Arc;
23
24use axum::{
25    Json,
26    extract::{Query, State},
27    http::{StatusCode, header},
28    response::{IntoResponse, Redirect, Response},
29};
30use serde::{Deserialize, Serialize};
31
32use crate::auth::{OidcServerClient, PkceStateStore};
33
34/// Shared state injected into both PKCE route handlers.
35pub struct AuthPkceState {
36    /// In-memory PKCE state store (encrypted when `state_encryption` is on).
37    pub pkce_store:              Arc<PkceStateStore>,
38    /// Server-side OIDC client for building authorize URLs and exchanging codes.
39    pub oidc_client:             Arc<OidcServerClient>,
40    /// Shared HTTP client for token-endpoint calls.
41    pub http_client:             Arc<reqwest::Client>,
42    /// When set, the callback redirects here with the token in a
43    /// `Secure; HttpOnly; SameSite=Strict` cookie instead of returning JSON.
44    pub post_login_redirect_uri: Option<String>,
45}
46
47// ---------------------------------------------------------------------------
48// Query parameter structs
49// ---------------------------------------------------------------------------
50
51/// Query parameters accepted by `GET /auth/start`.
52#[derive(Deserialize)]
53pub struct AuthStartQuery {
54    /// The URI within the **client application** to redirect to after a
55    /// successful login.  This is stored in the PKCE state store and
56    /// returned to the caller at callback time via the `redirect_uri` in
57    /// the consumed state.
58    redirect_uri: String,
59}
60
61/// Query parameters sent by the OIDC provider to `GET /auth/callback`.
62#[derive(Deserialize)]
63pub struct AuthCallbackQuery {
64    /// Authorization code to exchange for tokens.
65    code:              Option<String>,
66    /// State token for CSRF and PKCE state lookup.
67    state:             Option<String>,
68    /// OIDC provider error code (e.g. `"access_denied"`).
69    error:             Option<String>,
70    /// Human-readable error description from the provider.
71    error_description: Option<String>,
72}
73
74// ---------------------------------------------------------------------------
75// Response body (JSON path)
76// ---------------------------------------------------------------------------
77
78#[derive(Serialize)]
79struct TokenJson {
80    access_token: String,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    id_token:     Option<String>,
83    #[serde(skip_serializing_if = "Option::is_none")]
84    expires_in:   Option<u64>,
85    token_type:   &'static str,
86}
87
88// ---------------------------------------------------------------------------
89// Helpers
90// ---------------------------------------------------------------------------
91
92fn auth_error(status: StatusCode, message: &str) -> Response {
93    (status, Json(serde_json::json!({ "error": message }))).into_response()
94}
95
96// ---------------------------------------------------------------------------
97// GET /auth/start
98// ---------------------------------------------------------------------------
99
100/// Initiate a PKCE authorization code flow.
101///
102/// Generates a `code_verifier` and `code_challenge`, stores state in the
103/// [`PkceStateStore`], then redirects the user-agent to the OIDC provider.
104///
105/// # Query parameters
106///
107/// - `redirect_uri` — **required**: the client application's callback URI.
108///
109/// # Responses
110///
111/// - `302` — redirect to the OIDC provider's `/authorize` endpoint.
112/// - `400` — `redirect_uri` is missing.
113/// - `500` — internal error generating state (essentially impossible).
114pub async fn auth_start(
115    State(state): State<Arc<AuthPkceState>>,
116    Query(q): Query<AuthStartQuery>,
117) -> Response {
118    if q.redirect_uri.is_empty() {
119        return auth_error(StatusCode::BAD_REQUEST, "redirect_uri is required");
120    }
121    // Enforce a length cap to prevent memory amplification via the PKCE state store
122    // (in-memory or Redis) and to limit encrypted state blob size.
123    if q.redirect_uri.len() > 2048 {
124        return auth_error(StatusCode::BAD_REQUEST, "redirect_uri exceeds maximum length");
125    }
126
127    let (outbound_token, verifier) = match state.pkce_store.create_state(&q.redirect_uri).await {
128        Ok(v) => v,
129        Err(e) => {
130            tracing::error!("pkce create_state failed: {e}");
131            return auth_error(
132                StatusCode::INTERNAL_SERVER_ERROR,
133                "authorization flow could not be started",
134            );
135        },
136    };
137
138    let challenge = PkceStateStore::s256_challenge(&verifier);
139    let location = state.oidc_client.authorization_url(&outbound_token, &challenge, "S256");
140
141    Redirect::to(&location).into_response()
142}
143
144// ---------------------------------------------------------------------------
145// GET /auth/callback
146// ---------------------------------------------------------------------------
147
148/// Complete the PKCE authorization code flow.
149///
150/// Validates the `state` parameter, recovers the `code_verifier`, then
151/// exchanges the authorization `code` at the OIDC token endpoint.
152///
153/// # Query parameters
154///
155/// - `code`  — authorization code from the provider.
156/// - `state` — state token (may be encrypted).
157///
158/// The provider may also call this endpoint with `?error=…` when the user
159/// denies access; those are surfaced as `400` responses.
160///
161/// # Responses
162///
163/// - `200` JSON `{ access_token, id_token?, expires_in?, token_type }`. Or `302` with `Set-Cookie`
164///   when `post_login_redirect_uri` is configured.
165/// - `400` — invalid/expired state, missing parameters, or provider error.
166/// - `502` — token exchange with the OIDC provider failed.
167#[allow(clippy::cognitive_complexity)] // Reason: OAuth callback handler with state validation, token exchange, and redirect logic
168pub async fn auth_callback(
169    State(state): State<Arc<AuthPkceState>>,
170    Query(q): Query<AuthCallbackQuery>,
171) -> Response {
172    // ── Surface OIDC provider errors immediately ──────────────────────────
173    if let Some(err) = q.error {
174        let desc = q.error_description.as_deref().unwrap_or("(no description provided)");
175        // Log the full provider response for debugging, but return only a
176        // fixed allowlisted message to the client to avoid leaking internal
177        // provider details (tenant info, stack traces) or enabling injection.
178        tracing::warn!(oidc_error = %err, description = %desc, "OIDC provider returned error");
179        let client_message = match err.as_str() {
180            "access_denied" => "Access was denied",
181            "login_required" => "Authentication is required",
182            "invalid_request" | "invalid_scope" => "Invalid authorization request",
183            "server_error" | "temporarily_unavailable" => "Authorization server error",
184            _ => "Authorization failed",
185        };
186        return auth_error(StatusCode::BAD_REQUEST, client_message);
187    }
188
189    // ── Validate required parameters ──────────────────────────────────────
190    let (Some(code), Some(state_token)) = (q.code, q.state) else {
191        return auth_error(StatusCode::BAD_REQUEST, "missing code or state parameter");
192    };
193
194    // ── Consume PKCE state (atomic remove) ───────────────────────────────
195    let pkce = match state.pkce_store.consume_state(&state_token).await {
196        Ok(s) => s,
197        Err(e) => {
198            // Both StateNotFound and StateExpired are client errors.
199            // Log at debug to avoid spamming warnings from probing attacks.
200            tracing::debug!(error = %e, "pkce consume_state failed");
201            return auth_error(StatusCode::BAD_REQUEST, &e.to_string());
202        },
203    };
204
205    // ── Exchange code + verifier at the OIDC provider ────────────────────
206    let tokens = match state
207        .oidc_client
208        .exchange_code(&code, &pkce.verifier, &state.http_client)
209        .await
210    {
211        Ok(t) => t,
212        Err(e) => {
213            tracing::error!("token exchange failed: {e}");
214            return auth_error(StatusCode::BAD_GATEWAY, "token exchange with OIDC provider failed");
215        },
216    };
217
218    // ── Return tokens ─────────────────────────────────────────────────────
219    if let Some(redirect_uri) = &state.post_login_redirect_uri {
220        // Browser flow: redirect to frontend, set token in HttpOnly cookie.
221        // The redirect target is server-configured (not from pkce.redirect_uri —
222        // IMPORTANT: pkce.redirect_uri MUST NOT be used to construct an HTTP
223        // redirect without allowlist validation; its value is caller-supplied
224        // and could be attacker-controlled).
225        //
226        // Cookie notes:
227        // - `__Host-` prefix mandates Secure, Path=/, no Domain, blocking subdomain override.
228        // - Token value is double-quoted (RFC 6265 quoted-string) to safely embed any printable
229        //   ASCII that OAuth servers may include.
230        // - Max-Age uses 300s when expires_in is absent — a conservative default that prevents the
231        //   cookie outliving a short-lived token by a large margin.
232        let max_age = tokens.expires_in.unwrap_or(300);
233        // Escape '"' and '\' inside the token value per RFC 6265 quoted-string rules.
234        let token_escaped = tokens.access_token.replace('\\', r"\\").replace('"', r#"\""#);
235        let cookie = format!(
236            r#"__Host-access_token="{token_escaped}"; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={max_age}"#,
237        );
238        let mut resp = Redirect::to(redirect_uri).into_response();
239        match cookie.parse() {
240            Ok(value) => {
241                resp.headers_mut().insert(header::SET_COOKIE, value);
242            },
243            Err(e) => {
244                tracing::error!("Failed to parse Set-Cookie header: {e}");
245                return auth_error(
246                    StatusCode::INTERNAL_SERVER_ERROR,
247                    "session cookie could not be set",
248                );
249            },
250        }
251        resp
252    } else {
253        // API / native app flow: return tokens as JSON.
254        Json(TokenJson {
255            access_token: tokens.access_token,
256            id_token:     tokens.id_token,
257            expires_in:   tokens.expires_in,
258            token_type:   "Bearer",
259        })
260        .into_response()
261    }
262}
263
264// ---------------------------------------------------------------------------
265// POST /auth/revoke
266// ---------------------------------------------------------------------------
267
268/// Request body for token revocation.
269#[derive(Deserialize)]
270pub struct RevokeTokenRequest {
271    /// The JWT to revoke (we extract `jti` and `exp` from it).
272    pub token: String,
273}
274
275/// Response body for token revocation.
276#[derive(Serialize)]
277pub struct RevokeTokenResponse {
278    /// Whether the token was successfully revoked.
279    pub revoked:    bool,
280    /// ISO-8601 timestamp at which the revocation record will expire, if known.
281    #[serde(skip_serializing_if = "Option::is_none")]
282    pub expires_at: Option<String>,
283}
284
285/// Shared state for revocation routes.
286pub struct RevocationRouteState {
287    /// Token revocation manager used to record and check revoked JTIs.
288    pub revocation_manager: std::sync::Arc<crate::token_revocation::TokenRevocationManager>,
289}
290
291/// Revoke a single JWT by its `jti` claim.
292///
293/// The token is decoded (without verification — we only need the claims) to
294/// extract `jti` and `exp`.  The revocation entry TTL is set to the remaining
295/// token lifetime so the store auto-cleans.
296///
297/// # Responses
298///
299/// - `200` — token revoked successfully.
300/// - `400` — token is missing or has no `jti` claim.
301pub async fn revoke_token(
302    State(state): State<std::sync::Arc<RevocationRouteState>>,
303    Json(body): Json<RevokeTokenRequest>,
304) -> Response {
305    #[derive(serde::Deserialize)]
306    struct MinimalClaims {
307        jti: Option<String>,
308        exp: Option<u64>,
309    }
310
311    // Decode without signature verification — we only need the claims for revocation.
312    let claims = match jsonwebtoken::dangerous::insecure_decode::<MinimalClaims>(&body.token) {
313        Ok(data) => data.claims,
314        Err(e) => {
315            return auth_error(StatusCode::BAD_REQUEST, &format!("Invalid token: {e}"));
316        },
317    };
318
319    let jti = match claims.jti {
320        Some(j) if !j.is_empty() => j,
321        _ => {
322            return auth_error(StatusCode::BAD_REQUEST, "Token has no jti claim");
323        },
324    };
325
326    // TTL = remaining token lifetime, or 24h if no exp.
327    let ttl_secs = claims
328        .exp
329        .and_then(|exp| {
330            let now = chrono::Utc::now().timestamp().cast_unsigned();
331            exp.checked_sub(now)
332        })
333        .unwrap_or(86400);
334
335    if let Err(e) = state.revocation_manager.revoke(&jti, ttl_secs).await {
336        tracing::error!(error = %e, "Failed to revoke token");
337        return auth_error(StatusCode::INTERNAL_SERVER_ERROR, "Failed to revoke token");
338    }
339
340    let expires_at = claims.exp.map(|exp| {
341        chrono::DateTime::from_timestamp(exp.cast_signed(), 0)
342            .map_or_else(|| exp.to_string(), |dt| dt.to_rfc3339())
343    });
344
345    Json(RevokeTokenResponse {
346        revoked: true,
347        expires_at,
348    })
349    .into_response()
350}
351
352// ---------------------------------------------------------------------------
353// POST /auth/revoke-all
354// ---------------------------------------------------------------------------
355
356/// Request body for revoking all tokens for a user.
357#[derive(Deserialize)]
358pub struct RevokeAllRequest {
359    /// User subject (from JWT `sub` claim).
360    pub sub: String,
361}
362
363/// Response body for bulk revocation.
364#[derive(Serialize)]
365pub struct RevokeAllResponse {
366    /// Number of token revocation records that were created.
367    pub revoked_count: u64,
368}
369
370/// Revoke all tokens for a user.
371///
372/// # Responses
373///
374/// - `200` — tokens revoked.
375/// - `400` — `sub` is missing or empty.
376pub async fn revoke_all_tokens(
377    State(state): State<std::sync::Arc<RevocationRouteState>>,
378    Json(body): Json<RevokeAllRequest>,
379) -> Response {
380    if body.sub.is_empty() {
381        return auth_error(StatusCode::BAD_REQUEST, "sub is required");
382    }
383
384    match state.revocation_manager.revoke_all_for_user(&body.sub).await {
385        Ok(count) => Json(RevokeAllResponse {
386            revoked_count: count,
387        })
388        .into_response(),
389        Err(e) => {
390            tracing::error!(error = %e, sub = %body.sub, "Failed to revoke tokens for user");
391            auth_error(StatusCode::INTERNAL_SERVER_ERROR, "Failed to revoke tokens")
392        },
393    }
394}
395
396// ---------------------------------------------------------------------------
397// GET /auth/me
398// ---------------------------------------------------------------------------
399
/// Configuration for [`auth_me`], taken from the `[auth.me]` config section.
pub struct AuthMeState {
    /// Allowlist of raw JWT claim names to echo back, in addition to the
    /// `sub`, `user_id`, and `expires_at` fields the handler always emits.
    pub expose_claims: Vec<String>,
}
406
407/// Return the current session's identity as JSON.
408///
409/// Reads the [`crate::middleware::AuthUser`] request extension populated by
410/// `oidc_auth_middleware` and reflects a configurable subset of the validated
411/// JWT claims back to the caller.
412///
413/// The response always contains:
414/// - `sub` — the standard JWT subject (user ID).
415/// - `user_id` — hardcoded alias for `sub`; more ergonomic for frontend code.
416/// - `expires_at` — ISO-8601 timestamp when the session expires.
417///
418/// Additional fields are included only when (a) the claim name appears in the
419/// `expose_claims` allowlist **and** (b) the claim is present in the token.
420/// Claims in the allowlist but absent from the token are silently omitted —
421/// the response is never padded with `null` values.
422///
423/// The `user_id` alias for `sub` is always present and does **not** need to
424/// be listed in `expose_claims`.  Listing `"user_id"` there would silently
425/// return nothing because the JWT only carries `sub`, not `user_id`.
426///
427/// # Responses
428///
429/// - `200` JSON `{ sub, user_id, expires_at, ...expose_claims }`
430/// - `401` when no valid session is present (enforced by `oidc_auth_middleware`
431///   before this handler is called).
432pub async fn auth_me(
433    axum::extract::State(state): axum::extract::State<std::sync::Arc<AuthMeState>>,
434    axum::Extension(auth_user): axum::Extension<crate::middleware::AuthUser>,
435) -> axum::response::Response {
436    use axum::Json;
437    use axum::response::IntoResponse as _;
438
439    let user = &auth_user.0;
440
441    let mut map = serde_json::Map::new();
442    map.insert("sub".to_owned(), serde_json::Value::String(user.user_id.clone()));
443    map.insert("user_id".to_owned(), serde_json::Value::String(user.user_id.clone()));
444    map.insert(
445        "expires_at".to_owned(),
446        serde_json::Value::String(user.expires_at.to_rfc3339()),
447    );
448
449    for claim_name in &state.expose_claims {
450        if let Some(value) = user.extra_claims.get(claim_name) {
451            map.insert(claim_name.clone(), value.clone());
452        }
453    }
454
455    Json(serde_json::Value::Object(map)).into_response()
456}
457
458// ---------------------------------------------------------------------------
459// Unit tests
460// ---------------------------------------------------------------------------
461
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use std::collections::HashMap;

    use axum::{Extension, Router, body::Body, http::Request, routing::get};
    use chrono::Utc;
    use tower::ServiceExt as _;

    use super::*;
    use crate::auth::PkceStateStore;
    use crate::middleware::AuthUser;

    // -------------------------------------------------------------------------
    // Shared helpers
    // -------------------------------------------------------------------------

    /// Send a single `GET uri` request through `app` and return the response.
    async fn get_response(app: Router, uri: &str) -> Response {
        let request = Request::builder().uri(uri).body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap()
    }

    /// Drain the response body and parse it as JSON.
    async fn json_of(resp: Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Extract the `Location` header of a redirect as an owned string.
    fn location_of(resp: &Response) -> String {
        resp.headers()
            .get(header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .expect("Location header must be present")
            .to_owned()
    }

    fn mock_pkce_store() -> Arc<PkceStateStore> {
        Arc::new(PkceStateStore::new(600, None))
    }

    // -------------------------------------------------------------------------
    // auth_me tests
    // -------------------------------------------------------------------------

    fn make_auth_user(user_id: &str, extra: HashMap<String, serde_json::Value>) -> AuthUser {
        AuthUser(fraiseql_core::security::AuthenticatedUser {
            user_id:      user_id.to_owned(),
            scopes:       vec![],
            expires_at:   Utc::now() + chrono::Duration::hours(1),
            extra_claims: extra,
        })
    }

    fn make_me_state(expose_claims: &[&str]) -> Arc<AuthMeState> {
        Arc::new(AuthMeState {
            expose_claims: expose_claims.iter().copied().map(str::to_owned).collect(),
        })
    }

    /// Router exposing only `/auth/me`, with `user` pre-injected as the session.
    fn me_app(user: AuthUser, expose_claims: &[&str]) -> Router {
        Router::new()
            .route("/auth/me", get(auth_me))
            .layer(Extension(user))
            .with_state(make_me_state(expose_claims))
    }

    #[tokio::test]
    async fn test_auth_me_always_returns_sub_user_id_expires_at() {
        let app = me_app(make_auth_user("user-123", HashMap::new()), &[]);

        let resp = get_response(app, "/auth/me").await;
        assert_eq!(resp.status(), axum::http::StatusCode::OK);

        let json = json_of(resp).await;
        assert_eq!(json["sub"], "user-123");
        assert_eq!(json["user_id"], "user-123");
        assert!(json["expires_at"].is_string(), "expires_at must be present");
    }

    #[tokio::test]
    async fn test_auth_me_expose_claims_filters_correctly() {
        let extra = HashMap::from([
            ("email".to_owned(), serde_json::json!("alice@example.com")),
            ("https://myapp.com/role".to_owned(), serde_json::json!("admin")),
        ]);
        let app = me_app(make_auth_user("alice", extra), &["email"]);

        let json = json_of(get_response(app, "/auth/me").await).await;

        assert_eq!(json["email"], "alice@example.com", "listed claim must appear");
        assert!(
            json.get("https://myapp.com/role").is_none(),
            "unlisted claim must be absent"
        );
    }

    #[tokio::test]
    async fn test_auth_me_claim_absent_from_token_silently_omitted() {
        // The allowlist names "tenant_id", but the token does not carry it.
        let app = me_app(make_auth_user("user-x", HashMap::new()), &["tenant_id"]);

        let json = json_of(get_response(app, "/auth/me").await).await;

        assert!(json.get("tenant_id").is_none(), "absent claim must not be null-padded");
        // The fixed fields are still emitted.
        assert_eq!(json["sub"], "user-x");
    }

    #[tokio::test]
    async fn test_auth_me_namespaced_claim_in_expose_claims() {
        let extra =
            HashMap::from([("https://myapp.com/role".to_owned(), serde_json::json!("editor"))]);
        let app = me_app(make_auth_user("user-y", extra), &["https://myapp.com/role"]);

        let json = json_of(get_response(app, "/auth/me").await).await;

        assert_eq!(json["https://myapp.com/role"], "editor");
    }

    // -------------------------------------------------------------------------
    // PKCE route tests
    // -------------------------------------------------------------------------

    fn mock_oidc_client() -> Arc<OidcServerClient> {
        Arc::new(OidcServerClient::new(
            "test-client",
            "test-secret",
            "https://api.example.com/auth/callback",
            "https://provider.example.com/authorize",
            "https://provider.example.com/token",
        ))
    }

    /// Router serving `/auth/start` and `/auth/callback` over `pkce_store`.
    fn pkce_router(pkce_store: Arc<PkceStateStore>) -> Router {
        let auth_state = Arc::new(AuthPkceState {
            pkce_store,
            oidc_client: mock_oidc_client(),
            http_client: Arc::new(reqwest::Client::new()),
            post_login_redirect_uri: None,
        });
        Router::new()
            .route("/auth/start", get(auth_start))
            .route("/auth/callback", get(auth_callback))
            .with_state(auth_state)
    }

    fn auth_router() -> Router {
        pkce_router(mock_pkce_store())
    }

    #[tokio::test]
    async fn test_auth_start_redirects_with_pkce_params() {
        let resp = get_response(
            auth_router(),
            "/auth/start?redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb",
        )
        .await;

        // axum's Redirect::to() answers 303 See Other; accept any 3xx redirect.
        assert!(resp.status().is_redirection(), "expected redirect, got {}", resp.status());
        let location = location_of(&resp);

        let expectations = [
            ("response_type=code", "missing response_type"),
            ("code_challenge=", "missing code_challenge"),
            ("code_challenge_method=S256", "missing challenge method"),
            ("state=", "missing state param"),
            ("client_id=test-client", "missing client_id"),
        ];
        for (needle, failure) in expectations {
            assert!(location.contains(needle), "{failure}");
        }
    }

    #[tokio::test]
    async fn test_auth_start_missing_redirect_uri_returns_400() {
        let resp = get_response(auth_router(), "/auth/start").await;
        // Without the query param, either the extractor rejects the request or
        // the handler's guard does; anything other than a client error is wrong.
        assert!(
            resp.status().is_client_error(),
            "missing redirect_uri must be a client error, got {}",
            resp.status()
        );
    }

    #[tokio::test]
    async fn test_auth_callback_unknown_state_returns_400() {
        let resp = get_response(
            auth_router(),
            "/auth/callback?code=abc&state=completely-unknown-state",
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // The client must get a JSON error string, not an internal panic.
        let json = json_of(resp).await;
        assert!(json["error"].is_string(), "error field must be a string: {json}");
    }

    #[tokio::test]
    async fn test_auth_callback_missing_code_returns_400() {
        let resp = get_response(auth_router(), "/auth/callback?state=some-state-no-code").await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_auth_start_oversized_redirect_uri_returns_400() {
        let long_uri = format!("https://example.com/{}", "a".repeat(2100));
        let uri = format!("/auth/start?redirect_uri={}", urlencoding::encode(&long_uri));

        let resp = get_response(auth_router(), &uri).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = json_of(resp).await;
        assert!(
            json["error"].as_str().unwrap_or("").contains("maximum length"),
            "error must mention length: {json}"
        );
    }

    #[tokio::test]
    async fn test_auth_callback_oidc_error_returns_mapped_message() {
        // access_denied maps to a fixed message; provider text is never reflected.
        let resp = get_response(
            auth_router(),
            "/auth/callback?error=access_denied&error_description=internal+tenant+info",
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = json_of(resp).await;
        let error_msg = json["error"].as_str().unwrap_or("");
        assert!(
            !error_msg.contains("internal tenant info"),
            "provider description must not be reflected to client: {error_msg}"
        );
        assert_eq!(error_msg, "Access was denied");
    }

    #[tokio::test]
    async fn test_auth_callback_unknown_oidc_error_returns_generic_message() {
        let resp = get_response(
            auth_router(),
            "/auth/callback?error=unknown_vendor_error&error_description=secret+details",
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = json_of(resp).await;
        assert_eq!(json["error"].as_str().unwrap_or(""), "Authorization failed");
    }

    #[tokio::test]
    async fn test_auth_callback_oidc_error_no_description_uses_fallback() {
        let resp = get_response(auth_router(), "/auth/callback?error=access_denied").await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // The fallback description only reaches the tracing log. The HTTP body
        // carries the sanitised allowlist message, and the handler must not panic.
        let json = json_of(resp).await;
        assert_eq!(json["error"].as_str().unwrap_or(""), "Access was denied");
    }

    /// Full HTTP-level PKCE round-trip: `/auth/start` → extract state → `/auth/callback`.
    ///
    /// The state token embedded in the `/auth/start` redirect must be accepted
    /// by `/auth/callback`. This proves that PKCE state, encrypted here,
    /// survives the trip through the HTTP layer.
    ///
    /// With no real OIDC provider, the callback fails at token exchange with
    /// 502. A 400 instead would mean the state was not found in the store.
    #[tokio::test]
    async fn test_auth_start_to_callback_state_roundtrip_with_encryption() {
        use crate::auth::{EncryptionAlgorithm, StateEncryptionService};

        let encryption = Arc::new(StateEncryptionService::from_raw_key(
            &[0u8; 32],
            EncryptionAlgorithm::Chacha20Poly1305,
        ));
        let app = pkce_router(Arc::new(PkceStateStore::new(600, Some(encryption))));

        // Step 1: /auth/start redirects with the encrypted state token.
        let start = get_response(
            app.clone(),
            "/auth/start?redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb",
        )
        .await;
        assert!(
            start.status().is_redirection(),
            "expected redirect from /auth/start, got {}",
            start.status(),
        );
        let location = location_of(&start);

        // Parse the URL properly, so that "state=" appearing elsewhere in the
        // URL cannot produce a false match.
        let parsed = reqwest::Url::parse(&location).expect("Location header must be a valid URL");
        let state_token = parsed
            .query_pairs()
            .find_map(|(key, value)| (key == "state").then(|| value.into_owned()))
            .expect("state= must appear in the redirect Location URL");
        assert!(!state_token.is_empty(), "extracted state token must not be empty");

        // Step 2: /auth/callback with the real token from step 1. Expect 502 from
        // the failed token exchange; a 400 would signal a PKCE state regression.
        let callback_uri = format!("/auth/callback?code=test_code&state={state_token}");
        let callback = get_response(app, &callback_uri).await;

        assert_ne!(
            callback.status(),
            StatusCode::BAD_REQUEST,
            "state from /auth/start must be accepted by /auth/callback; \
             400 means the PKCE state was not found or decryption failed",
        );
        assert_eq!(
            callback.status(),
            StatusCode::BAD_GATEWAY,
            "token exchange should fail 502 (no real OIDC provider); got {}",
            callback.status(),
        );
    }
}