1use std::sync::Arc;
23
24use axum::{
25 Json,
26 extract::{Query, State},
27 http::{StatusCode, header},
28 response::{IntoResponse, Redirect, Response},
29};
30use serde::{Deserialize, Serialize};
31
32use crate::auth::{OidcServerClient, PkceStateStore};
33
/// Shared state for the [`auth_start`] and [`auth_callback`] handlers.
pub struct AuthPkceState {
    /// Store that creates the per-login state/verifier pair in `auth_start`
    /// and consumes it again in `auth_callback`.
    pub pkce_store: Arc<PkceStateStore>,
    /// Client used to build the provider authorization URL and to exchange the
    /// authorization code for tokens.
    pub oidc_client: Arc<OidcServerClient>,
    /// HTTP client handed to `exchange_code` for the token request.
    pub http_client: Arc<reqwest::Client>,
    /// When set, `auth_callback` redirects here with the access token in a
    /// cookie; when `None`, it returns the tokens as a JSON body instead.
    pub post_login_redirect_uri: Option<String>,
}
46
/// Query-string parameters accepted by [`auth_start`].
#[derive(Deserialize)]
pub struct AuthStartQuery {
    /// Caller-supplied redirect URI. `auth_start` rejects it when empty or
    /// longer than 2048 bytes and otherwise passes it to the PKCE state store.
    redirect_uri: String,
}
60
/// Query-string parameters accepted by [`auth_callback`].
///
/// Every field is optional at the parsing level; the handler decides which
/// combination is valid.
#[derive(Deserialize)]
pub struct AuthCallbackQuery {
    /// Authorization code to exchange for tokens.
    code: Option<String>,
    /// State token that is handed to `PkceStateStore::consume_state`.
    state: Option<String>,
    /// Error code reported by the provider; when present it takes precedence
    /// over `code`/`state`.
    error: Option<String>,
    /// Provider-supplied error detail. It is only logged, never returned to
    /// the client.
    error_description: Option<String>,
}
73
/// JSON body returned by [`auth_callback`] when no post-login redirect is
/// configured.
#[derive(Serialize)]
struct TokenJson {
    /// Access token obtained from the code exchange.
    access_token: String,
    /// ID token from the code exchange; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    id_token: Option<String>,
    /// Token lifetime as reported by the exchange; omitted from the JSON when
    /// `None`. (Presumably seconds — confirm against the provider client.)
    #[serde(skip_serializing_if = "Option::is_none")]
    expires_in: Option<u64>,
    /// Token type string; `auth_callback` always sets this to `"Bearer"`.
    token_type: &'static str,
}
87
88fn auth_error(status: StatusCode, message: &str) -> Response {
93 (status, Json(serde_json::json!({ "error": message }))).into_response()
94}
95
96pub async fn auth_start(
115 State(state): State<Arc<AuthPkceState>>,
116 Query(q): Query<AuthStartQuery>,
117) -> Response {
118 if q.redirect_uri.is_empty() {
119 return auth_error(StatusCode::BAD_REQUEST, "redirect_uri is required");
120 }
121 if q.redirect_uri.len() > 2048 {
124 return auth_error(StatusCode::BAD_REQUEST, "redirect_uri exceeds maximum length");
125 }
126
127 let (outbound_token, verifier) = match state.pkce_store.create_state(&q.redirect_uri).await {
128 Ok(v) => v,
129 Err(e) => {
130 tracing::error!("pkce create_state failed: {e}");
131 return auth_error(
132 StatusCode::INTERNAL_SERVER_ERROR,
133 "authorization flow could not be started",
134 );
135 },
136 };
137
138 let challenge = PkceStateStore::s256_challenge(&verifier);
139 let location = state.oidc_client.authorization_url(&outbound_token, &challenge, "S256");
140
141 Redirect::to(&location).into_response()
142}
143
144#[allow(clippy::cognitive_complexity)] pub async fn auth_callback(
169 State(state): State<Arc<AuthPkceState>>,
170 Query(q): Query<AuthCallbackQuery>,
171) -> Response {
172 if let Some(err) = q.error {
174 let desc = q.error_description.as_deref().unwrap_or("(no description provided)");
175 tracing::warn!(oidc_error = %err, description = %desc, "OIDC provider returned error");
179 let client_message = match err.as_str() {
180 "access_denied" => "Access was denied",
181 "login_required" => "Authentication is required",
182 "invalid_request" | "invalid_scope" => "Invalid authorization request",
183 "server_error" | "temporarily_unavailable" => "Authorization server error",
184 _ => "Authorization failed",
185 };
186 return auth_error(StatusCode::BAD_REQUEST, client_message);
187 }
188
189 let (Some(code), Some(state_token)) = (q.code, q.state) else {
191 return auth_error(StatusCode::BAD_REQUEST, "missing code or state parameter");
192 };
193
194 let pkce = match state.pkce_store.consume_state(&state_token).await {
196 Ok(s) => s,
197 Err(e) => {
198 tracing::debug!(error = %e, "pkce consume_state failed");
201 return auth_error(StatusCode::BAD_REQUEST, &e.to_string());
202 },
203 };
204
205 let tokens = match state
207 .oidc_client
208 .exchange_code(&code, &pkce.verifier, &state.http_client)
209 .await
210 {
211 Ok(t) => t,
212 Err(e) => {
213 tracing::error!("token exchange failed: {e}");
214 return auth_error(StatusCode::BAD_GATEWAY, "token exchange with OIDC provider failed");
215 },
216 };
217
218 if let Some(redirect_uri) = &state.post_login_redirect_uri {
220 let max_age = tokens.expires_in.unwrap_or(300);
233 let token_escaped = tokens.access_token.replace('\\', r"\\").replace('"', r#"\""#);
235 let cookie = format!(
236 r#"__Host-access_token="{token_escaped}"; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={max_age}"#,
237 );
238 let mut resp = Redirect::to(redirect_uri).into_response();
239 match cookie.parse() {
240 Ok(value) => {
241 resp.headers_mut().insert(header::SET_COOKIE, value);
242 },
243 Err(e) => {
244 tracing::error!("Failed to parse Set-Cookie header: {e}");
245 return auth_error(
246 StatusCode::INTERNAL_SERVER_ERROR,
247 "session cookie could not be set",
248 );
249 },
250 }
251 resp
252 } else {
253 Json(TokenJson {
255 access_token: tokens.access_token,
256 id_token: tokens.id_token,
257 expires_in: tokens.expires_in,
258 token_type: "Bearer",
259 })
260 .into_response()
261 }
262}
263
/// JSON request body for [`revoke_token`].
#[derive(Deserialize)]
pub struct RevokeTokenRequest {
    /// The JWT to revoke. The handler reads its `jti` and `exp` claims without
    /// verifying the signature.
    pub token: String,
}
274
/// JSON response body produced by [`revoke_token`].
#[derive(Serialize)]
pub struct RevokeTokenResponse {
    /// Whether the token was revoked; the handler sets this to `true` on success.
    pub revoked: bool,
    /// The token's `exp` claim rendered as RFC 3339 (or as the raw number when
    /// it cannot be converted to a timestamp). Omitted from the JSON when the
    /// token carried no `exp`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<String>,
}
284
/// Shared state for the [`revoke_token`] and [`revoke_all_tokens`] handlers.
pub struct RevocationRouteState {
    /// Manager that records revocations (`revoke`, `revoke_all_for_user`).
    pub revocation_manager: std::sync::Arc<crate::token_revocation::TokenRevocationManager>,
}
290
291pub async fn revoke_token(
302 State(state): State<std::sync::Arc<RevocationRouteState>>,
303 Json(body): Json<RevokeTokenRequest>,
304) -> Response {
305 #[derive(serde::Deserialize)]
306 struct MinimalClaims {
307 jti: Option<String>,
308 exp: Option<u64>,
309 }
310
311 let claims = match jsonwebtoken::dangerous::insecure_decode::<MinimalClaims>(&body.token) {
313 Ok(data) => data.claims,
314 Err(e) => {
315 return auth_error(StatusCode::BAD_REQUEST, &format!("Invalid token: {e}"));
316 },
317 };
318
319 let jti = match claims.jti {
320 Some(j) if !j.is_empty() => j,
321 _ => {
322 return auth_error(StatusCode::BAD_REQUEST, "Token has no jti claim");
323 },
324 };
325
326 let ttl_secs = claims
328 .exp
329 .and_then(|exp| {
330 let now = chrono::Utc::now().timestamp().cast_unsigned();
331 exp.checked_sub(now)
332 })
333 .unwrap_or(86400);
334
335 if let Err(e) = state.revocation_manager.revoke(&jti, ttl_secs).await {
336 tracing::error!(error = %e, "Failed to revoke token");
337 return auth_error(StatusCode::INTERNAL_SERVER_ERROR, "Failed to revoke token");
338 }
339
340 let expires_at = claims.exp.map(|exp| {
341 chrono::DateTime::from_timestamp(exp.cast_signed(), 0)
342 .map_or_else(|| exp.to_string(), |dt| dt.to_rfc3339())
343 });
344
345 Json(RevokeTokenResponse {
346 revoked: true,
347 expires_at,
348 })
349 .into_response()
350}
351
/// JSON request body for [`revoke_all_tokens`].
#[derive(Deserialize)]
pub struct RevokeAllRequest {
    /// Subject whose tokens should be revoked; the handler rejects an empty value.
    pub sub: String,
}
362
/// JSON response body produced by [`revoke_all_tokens`].
#[derive(Serialize)]
pub struct RevokeAllResponse {
    /// The count returned by `revoke_all_for_user`.
    pub revoked_count: u64,
}
369
370pub async fn revoke_all_tokens(
377 State(state): State<std::sync::Arc<RevocationRouteState>>,
378 Json(body): Json<RevokeAllRequest>,
379) -> Response {
380 if body.sub.is_empty() {
381 return auth_error(StatusCode::BAD_REQUEST, "sub is required");
382 }
383
384 match state.revocation_manager.revoke_all_for_user(&body.sub).await {
385 Ok(count) => Json(RevokeAllResponse {
386 revoked_count: count,
387 })
388 .into_response(),
389 Err(e) => {
390 tracing::error!(error = %e, sub = %body.sub, "Failed to revoke tokens for user");
391 auth_error(StatusCode::INTERNAL_SERVER_ERROR, "Failed to revoke tokens")
392 },
393 }
394}
395
/// Shared state for the [`auth_me`] handler.
pub struct AuthMeState {
    /// Names of extra token claims that `auth_me` copies into its response
    /// when the authenticated user actually carries them.
    pub expose_claims: Vec<String>,
}
406
407pub async fn auth_me(
433 axum::extract::State(state): axum::extract::State<std::sync::Arc<AuthMeState>>,
434 axum::Extension(auth_user): axum::Extension<crate::middleware::AuthUser>,
435) -> axum::response::Response {
436 use axum::Json;
437 use axum::response::IntoResponse as _;
438
439 let user = &auth_user.0;
440
441 let mut map = serde_json::Map::new();
442 map.insert("sub".to_owned(), serde_json::Value::String(user.user_id.clone()));
443 map.insert("user_id".to_owned(), serde_json::Value::String(user.user_id.clone()));
444 map.insert(
445 "expires_at".to_owned(),
446 serde_json::Value::String(user.expires_at.to_rfc3339()),
447 );
448
449 for claim_name in &state.expose_claims {
450 if let Some(value) = user.extra_claims.get(claim_name) {
451 map.insert(claim_name.clone(), value.clone());
452 }
453 }
454
455 Json(serde_json::Value::Object(map)).into_response()
456}
457
// Router-level tests for the handlers in this file. Each test builds a small
// `Router`, drives it with `oneshot`, and inspects the status, headers, or
// JSON body of the response.
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code; a panic is the intended failure signal.

    use axum::{Extension, Router, body::Body, http::Request, routing::get};
    use chrono::Utc;
    use tower::ServiceExt as _;

    use super::*;
    use crate::auth::PkceStateStore;
    use crate::middleware::AuthUser;

    // ── Shared helpers ──────────────────────────────────────────────────────

    /// Sends a body-less GET for `uri` through `app` and returns the response.
    async fn send_get(app: Router, uri: &str) -> axum::response::Response {
        let req = Request::builder().uri(uri).body(Body::empty()).unwrap();
        app.oneshot(req).await.unwrap()
    }

    /// Reads the whole response body and parses it as JSON.
    async fn json_body(resp: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// PKCE store without state encryption.
    fn mock_pkce_store() -> Arc<PkceStateStore> {
        Arc::new(PkceStateStore::new(600, None))
    }

    /// Builds an `AuthUser` that expires one hour from now, has no scopes, and
    /// carries the given extra claims.
    fn make_auth_user(
        user_id: &str,
        extra: std::collections::HashMap<String, serde_json::Value>,
    ) -> AuthUser {
        AuthUser(fraiseql_core::security::AuthenticatedUser {
            user_id: user_id.to_owned(),
            scopes: vec![],
            expires_at: Utc::now() + chrono::Duration::hours(1),
            extra_claims: extra,
        })
    }

    /// Builds the `/auth/me` state exposing exactly the named claims.
    fn make_me_state(expose_claims: Vec<&str>) -> Arc<AuthMeState> {
        Arc::new(AuthMeState {
            expose_claims: expose_claims.into_iter().map(str::to_owned).collect(),
        })
    }

    /// Router serving only `/auth/me`, with `user` injected as the request
    /// extension the handler expects.
    fn me_router(user: AuthUser, expose_claims: Vec<&str>) -> Router {
        Router::new()
            .route("/auth/me", get(auth_me))
            .layer(Extension(user))
            .with_state(make_me_state(expose_claims))
    }

    /// OIDC client pointing at placeholder provider endpoints.
    fn mock_oidc_client() -> Arc<OidcServerClient> {
        Arc::new(OidcServerClient::new(
            "test-client",
            "test-secret",
            "https://api.example.com/auth/callback",
            "https://provider.example.com/authorize",
            "https://provider.example.com/token",
        ))
    }

    /// Router serving `/auth/start` and `/auth/callback` with no post-login
    /// redirect configured.
    fn auth_router() -> Router {
        let auth_state = Arc::new(AuthPkceState {
            pkce_store: mock_pkce_store(),
            oidc_client: mock_oidc_client(),
            http_client: Arc::new(reqwest::Client::new()),
            post_login_redirect_uri: None,
        });
        Router::new()
            .route("/auth/start", get(auth_start))
            .route("/auth/callback", get(auth_callback))
            .with_state(auth_state)
    }

    // ── /auth/me ────────────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_auth_me_always_returns_sub_user_id_expires_at() {
        let app = me_router(make_auth_user("user-123", std::collections::HashMap::new()), vec![]);

        let resp = send_get(app, "/auth/me").await;
        assert_eq!(resp.status(), axum::http::StatusCode::OK);

        let json = json_body(resp).await;
        assert_eq!(json["sub"], "user-123");
        assert_eq!(json["user_id"], "user-123");
        assert!(json["expires_at"].is_string(), "expires_at must be present");
    }

    #[tokio::test]
    async fn test_auth_me_expose_claims_filters_correctly() {
        let mut extra = std::collections::HashMap::new();
        extra.insert("email".to_owned(), serde_json::json!("alice@example.com"));
        extra.insert("https://myapp.com/role".to_owned(), serde_json::json!("admin"));

        let app = me_router(make_auth_user("alice", extra), vec!["email"]);
        let json = json_body(send_get(app, "/auth/me").await).await;

        assert_eq!(json["email"], "alice@example.com", "listed claim must appear");
        assert!(
            json.get("https://myapp.com/role").is_none(),
            "unlisted claim must be absent"
        );
    }

    #[tokio::test]
    async fn test_auth_me_claim_absent_from_token_silently_omitted() {
        let app = me_router(
            make_auth_user("user-x", std::collections::HashMap::new()),
            vec!["tenant_id"],
        );
        let json = json_body(send_get(app, "/auth/me").await).await;

        assert!(json.get("tenant_id").is_none(), "absent claim must not be null-padded");
        assert_eq!(json["sub"], "user-x");
    }

    #[tokio::test]
    async fn test_auth_me_namespaced_claim_in_expose_claims() {
        let mut extra = std::collections::HashMap::new();
        extra.insert("https://myapp.com/role".to_owned(), serde_json::json!("editor"));

        let app = me_router(make_auth_user("user-y", extra), vec!["https://myapp.com/role"]);
        let json = json_body(send_get(app, "/auth/me").await).await;

        assert_eq!(json["https://myapp.com/role"], "editor");
    }

    // ── /auth/start ─────────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_auth_start_redirects_with_pkce_params() {
        let resp = send_get(
            auth_router(),
            "/auth/start?redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb",
        )
        .await;

        assert!(resp.status().is_redirection(), "expected redirect, got {}", resp.status());
        let location = resp
            .headers()
            .get(header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .expect("Location header must be present");

        assert!(location.contains("response_type=code"), "missing response_type");
        assert!(location.contains("code_challenge="), "missing code_challenge");
        assert!(location.contains("code_challenge_method=S256"), "missing challenge method");
        assert!(location.contains("state="), "missing state param");
        assert!(location.contains("client_id=test-client"), "missing client_id");
    }

    #[tokio::test]
    async fn test_auth_start_missing_redirect_uri_returns_400() {
        let resp = send_get(auth_router(), "/auth/start").await;
        assert!(
            resp.status().is_client_error(),
            "missing redirect_uri must be a client error, got {}",
            resp.status()
        );
    }

    #[tokio::test]
    async fn test_auth_start_oversized_redirect_uri_returns_400() {
        let long_uri = "https://example.com/".to_string() + &"a".repeat(2100);
        let encoded = urlencoding::encode(&long_uri);

        let resp = send_get(auth_router(), &format!("/auth/start?redirect_uri={encoded}")).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = json_body(resp).await;
        assert!(
            json["error"].as_str().unwrap_or("").contains("maximum length"),
            "error must mention length: {json}"
        );
    }

    // ── /auth/callback ──────────────────────────────────────────────────────

    #[tokio::test]
    async fn test_auth_callback_unknown_state_returns_400() {
        let resp = send_get(
            auth_router(),
            "/auth/callback?code=abc&state=completely-unknown-state",
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = json_body(resp).await;
        assert!(json["error"].is_string(), "error field must be a string: {json}");
    }

    #[tokio::test]
    async fn test_auth_callback_missing_code_returns_400() {
        let resp = send_get(auth_router(), "/auth/callback?state=some-state-no-code").await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_auth_callback_oidc_error_returns_mapped_message() {
        let resp = send_get(
            auth_router(),
            "/auth/callback?error=access_denied&error_description=internal+tenant+info",
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = json_body(resp).await;
        let error_msg = json["error"].as_str().unwrap_or("");
        assert!(
            !error_msg.contains("internal tenant info"),
            "provider description must not be reflected to client: {error_msg}"
        );
        assert_eq!(error_msg, "Access was denied");
    }

    #[tokio::test]
    async fn test_auth_callback_unknown_oidc_error_returns_generic_message() {
        let resp = send_get(
            auth_router(),
            "/auth/callback?error=unknown_vendor_error&error_description=secret+details",
        )
        .await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = json_body(resp).await;
        assert_eq!(json["error"].as_str().unwrap_or(""), "Authorization failed");
    }

    #[tokio::test]
    async fn test_auth_callback_oidc_error_no_description_uses_fallback() {
        let resp = send_get(auth_router(), "/auth/callback?error=access_denied").await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let json = json_body(resp).await;
        assert_eq!(json["error"].as_str().unwrap_or(""), "Access was denied");
    }

    // ── start → callback round trip ─────────────────────────────────────────

    #[tokio::test]
    async fn test_auth_start_to_callback_state_roundtrip_with_encryption() {
        use crate::auth::{EncryptionAlgorithm, StateEncryptionService};

        let enc = Arc::new(StateEncryptionService::from_raw_key(
            &[0u8; 32],
            EncryptionAlgorithm::Chacha20Poly1305,
        ));
        let pkce_store = Arc::new(PkceStateStore::new(600, Some(enc)));

        let auth_state = Arc::new(AuthPkceState {
            pkce_store,
            oidc_client: mock_oidc_client(),
            http_client: Arc::new(reqwest::Client::new()),
            post_login_redirect_uri: None,
        });

        let app = Router::new()
            .route("/auth/start", get(auth_start))
            .route("/auth/callback", get(auth_callback))
            .with_state(auth_state);

        // Step 1: start the flow and capture the redirect target.
        let resp = send_get(
            app.clone(),
            "/auth/start?redirect_uri=https%3A%2F%2Fapp.example.com%2Fcb",
        )
        .await;

        assert!(
            resp.status().is_redirection(),
            "expected redirect from /auth/start, got {}",
            resp.status(),
        );

        let location = resp
            .headers()
            .get(header::LOCATION)
            .and_then(|v| v.to_str().ok())
            .expect("Location header must be set")
            .to_string();

        // Step 2: pull the state token out of the redirect URL.
        let parsed_location =
            reqwest::Url::parse(&location).expect("Location header must be a valid URL");
        let state_token = parsed_location
            .query_pairs()
            .find(|(k, _)| k == "state")
            .map(|(_, v)| v.into_owned())
            .expect("state= must appear in the redirect Location URL");

        assert!(!state_token.is_empty(), "extracted state token must not be empty");

        // Step 3: replay the token against the callback. `query_pairs` returned
        // the percent-decoded value, so it has to be re-encoded before it goes
        // back into a query string; otherwise characters such as `+`, `/`, `=`
        // or `%` would be altered or make the URI invalid.
        let encoded_state = urlencoding::encode(&state_token);
        let callback_uri = format!("/auth/callback?code=test_code&state={encoded_state}");
        let resp2 = send_get(app, &callback_uri).await;

        assert_ne!(
            resp2.status(),
            StatusCode::BAD_REQUEST,
            "state from /auth/start must be accepted by /auth/callback; \
             400 means the PKCE state was not found or decryption failed",
        );
        assert_eq!(
            resp2.status(),
            StatusCode::BAD_GATEWAY,
            "token exchange should fail 502 (no real OIDC provider); got {}",
            resp2.status(),
        );
    }
}