#![allow(clippy::unwrap_used)] #![allow(clippy::cast_precision_loss)] #![allow(clippy::cast_sign_loss)] #![allow(clippy::cast_possible_truncation)] #![allow(clippy::cast_possible_wrap)] #![allow(clippy::cast_lossless)] #![allow(clippy::missing_panics_doc)] #![allow(clippy::missing_errors_doc)] #![allow(missing_docs)] #![allow(clippy::items_after_statements)] #![allow(clippy::used_underscore_binding)] #![allow(clippy::needless_pass_by_value)]
use std::sync::Arc;
use axum::{Router, body::Body, routing::get};
use fraiseql_auth::OidcServerClient;
use fraiseql_server::{
auth::PkceStateStore,
routes::{AuthPkceState, auth_callback, auth_start},
};
use http::{Request, StatusCode};
use tower::ServiceExt;
use wiremock::{
Mock, MockServer, ResponseTemplate,
matchers::{method, path},
};
/// Builds the [`OidcServerClient`] shared by tests that must not reach a real IdP.
///
/// The token endpoint host `192.0.2.1` is an RFC 5737 documentation (TEST-NET-1) address,
/// so the authorization-code exchange performed by `/auth/callback` in these tests is
/// expected to fail rather than succeed.
fn test_oidc_client() -> OidcServerClient {
    let client_id = "test-client";
    let client_secret = "test-secret";
    let callback_url = "http://localhost/auth/callback";
    let authorize_url = "https://auth.example.com/authorize";
    let token_url = "https://192.0.2.1/token";

    OidcServerClient::new(client_id, client_secret, callback_url, authorize_url, token_url)
}
/// Builds a router exposing `/auth/start` and `/auth/callback`, backed by an in-memory
/// [`PkceStateStore`] and [`test_oidc_client`], with no post-login redirect configured.
///
/// The HTTP client is given a bounded connect timeout. The configured token endpoint is a
/// TEST-NET-1 address that may silently drop packets, and `reqwest::Client::new()` has no
/// timeout at all. If the callback's token exchange goes through this client (presumably it
/// does — it is the only HTTP client in the state), a test could otherwise stall for the
/// operating system's TCP connect timeout before observing the expected failure.
fn auth_router() -> Router {
    let http_client = reqwest::Client::builder()
        .connect_timeout(std::time::Duration::from_secs(5))
        .build()
        .expect("reqwest client with a connect timeout must build");

    let state = Arc::new(AuthPkceState {
        pkce_store: Arc::new(PkceStateStore::new(300, None)),
        oidc_client: Arc::new(test_oidc_client()),
        http_client: Arc::new(http_client),
        post_login_redirect_uri: None,
    });

    Router::new()
        .route("/auth/start", get(auth_start))
        .route("/auth/callback", get(auth_callback))
        .with_state(state)
}
/// Sends a `GET uri` through a clone of `router`.
///
/// Returns the response status together with the `Location` header. The header is `None`
/// when it is absent or cannot be converted to a string.
///
/// # Panics
///
/// Panics if the request cannot be built or the router call fails.
async fn get_request(router: &Router, uri: &str) -> (StatusCode, Option<String>) {
    let request = Request::builder().uri(uri).body(Body::empty()).unwrap();
    let response = router.clone().oneshot(request).await.unwrap();

    let location = match response.headers().get("location") {
        Some(value) => value.to_str().ok().map(ToOwned::to_owned),
        None => None,
    };

    (response.status(), location)
}
/// Returns the raw (still percent-encoded, if it was) value of the `state` query parameter
/// in `url`.
///
/// Only a parameter whose key is exactly `state` matches. Keys that merely end in `state`
/// (e.g. `session_state`) are not mistaken for it. Anything after a `#` fragment marker is
/// ignored. If `url` contains no `?`, the whole string is treated as the query.
///
/// # Panics
///
/// Panics if `url` has no `state` query parameter.
fn extract_state_param(url: &str) -> &str {
    // Drop the fragment first so `...&state=abc#frag` yields `abc`.
    let without_fragment = url.split_once('#').map_or(url, |(before, _)| before);
    let query = without_fragment
        .split_once('?')
        .map_or(without_fragment, |(_, query)| query);

    query
        .split('&')
        .find_map(|pair| pair.strip_prefix("state="))
        .expect("redirect URL must contain state= parameter")
}
/// `GET /auth/start` must answer 303 and send the browser to the configured IdP.
///
/// The redirect URL must carry a PKCE `code_challenge` and an opaque `state` token.
#[tokio::test]
async fn auth_start_redirects_to_idp() {
    const START_URI: &str = "/auth/start?redirect_uri=https://app.example.com/after-login";

    let app = auth_router();
    let (status, location) = get_request(&app, START_URI).await;
    assert_eq!(status, StatusCode::SEE_OTHER, "auth_start must redirect (303)");

    let Some(loc) = location else {
        panic!("auth_start must set Location header");
    };
    assert!(
        loc.contains("auth.example.com"),
        "redirect must point to the configured IdP: {loc}"
    );
    assert!(
        loc.contains("code_challenge"),
        "redirect must include PKCE code_challenge: {loc}"
    );
    assert!(loc.contains("state="), "redirect must include opaque state token: {loc}");
}
/// End-to-end PKCE handshake against the in-memory store.
///
/// The `state` issued by `/auth/start` must be accepted by the first `/auth/callback`. That
/// call is expected to fail only later, at the token exchange. The same `state` must then
/// be rejected with 400 when replayed.
#[tokio::test]
async fn auth_start_then_callback_completes_pkce_flow() {
    let app = auth_router();

    // Step 1: obtain a fresh state token from the IdP redirect.
    let (start_status, start_location) =
        get_request(&app, "/auth/start?redirect_uri=https://app.example.com/after-login").await;
    assert_eq!(start_status, StatusCode::SEE_OTHER, "auth_start must redirect (303)");
    let idp_url = start_location.expect("auth_start must provide Location header");
    let state_token = extract_state_param(&idp_url);
    assert!(!state_token.is_empty(), "state token must not be empty");

    // Step 2: the first callback must get past the state lookup.
    let callback_uri = format!("/auth/callback?code=fake_code&state={state_token}");
    let callback_status = get_request(&app, &callback_uri).await.0;
    assert_ne!(
        callback_status,
        StatusCode::BAD_REQUEST,
        "state token must be valid — failure must come from IdP exchange \
         (502), not state lookup (400). Got: {callback_status}"
    );

    // Step 3: replaying the identical callback must hit an already-consumed state.
    let replay_status = get_request(&app, &callback_uri).await.0;
    assert_eq!(
        replay_status,
        StatusCode::BAD_REQUEST,
        "second use of the same state token must be rejected (state consumed)"
    );
}
/// `/auth/start` without a `redirect_uri` query parameter must be rejected with 400.
#[tokio::test]
async fn auth_start_missing_redirect_uri_returns_400() {
    let (status, _location) = get_request(&auth_router(), "/auth/start").await;
    assert_eq!(status, StatusCode::BAD_REQUEST, "missing redirect_uri must return 400");
}
/// A callback carrying a `state` value that `/auth/start` never issued must get 400.
#[tokio::test]
async fn auth_callback_unknown_state_returns_400() {
    let uri = "/auth/callback?code=any_code&state=unknown-state-token";
    let (status, _location) = get_request(&auth_router(), uri).await;
    assert_eq!(status, StatusCode::BAD_REQUEST, "unknown state token must return 400");
}
/// A callback carrying only an `error` parameter (here `access_denied`), with no `code` or
/// `state`, must get 400.
#[tokio::test]
async fn auth_callback_provider_error_returns_400() {
    let uri = "/auth/callback?error=access_denied";
    let (status, _location) = get_request(&auth_router(), uri).await;
    assert_eq!(status, StatusCode::BAD_REQUEST, "provider error must return 400");
}
/// A bare `/auth/callback` with no query parameters at all must get 400.
#[tokio::test]
async fn auth_callback_missing_code_and_state_returns_400() {
    let uri = "/auth/callback";
    let (status, _location) = get_request(&auth_router(), uri).await;
    assert_eq!(status, StatusCode::BAD_REQUEST, "callback with no params must return 400");
}
/// Builds the router used by `auth_callback_session_cookie_mode`.
///
/// `POST /token` on `mock_server` is stubbed to return a fixed token response
/// (`expires_in` = 3600). The OIDC client's token endpoint points at that mock. The state
/// sets `post_login_redirect_uri` to `https://app.example.com/dashboard`.
async fn session_cookie_router(mock_server: &MockServer) -> Router {
    let token_response = serde_json::json!({
        "access_token": "test-access-token-xyz",
        "id_token": "test-id-token",
        "expires_in": 3600,
        "token_type": "Bearer"
    });
    Mock::given(method("POST"))
        .and(path("/token"))
        .respond_with(ResponseTemplate::new(200).set_body_json(token_response))
        .mount(mock_server)
        .await;

    let token_url = format!("{}/token", mock_server.uri());
    let auth_state = AuthPkceState {
        pkce_store: Arc::new(PkceStateStore::new(300, None)),
        oidc_client: Arc::new(OidcServerClient::new(
            "test-client",
            "test-secret",
            "http://localhost/auth/callback",
            "https://auth.example.com/authorize",
            token_url,
        )),
        http_client: Arc::new(reqwest::Client::new()),
        post_login_redirect_uri: Some("https://app.example.com/dashboard".to_string()),
    };

    Router::new()
        .route("/auth/start", get(auth_start))
        .route("/auth/callback", get(auth_callback))
        .with_state(Arc::new(auth_state))
}
/// Tests the callback with `post_login_redirect_uri` configured (see `session_cookie_router`).
///
/// A successful callback must redirect to that URI rather than the caller's `redirect_uri`.
/// It must also deliver the access token in a `__Host-access_token` cookie that is
/// `HttpOnly`, `Secure`, `SameSite=Strict`, scoped to `Path=/`, has no `Domain`, and
/// carries a `Max-Age` equal to the token's `expires_in` (3600).
#[tokio::test]
async fn auth_callback_session_cookie_mode() {
    let mock_server = MockServer::start().await;
    let router = session_cookie_router(&mock_server).await;
    let (status, location) =
        get_request(&router, "/auth/start?redirect_uri=https://app.example.com/after-login").await;
    assert_eq!(status, StatusCode::SEE_OTHER, "auth_start must redirect (303)");
    let loc = location.expect("auth_start must set Location header");
    let state_token = extract_state_param(&loc);
    assert!(!state_token.is_empty(), "state token must not be empty");

    // `get_request` only surfaces `Location`; the cookie assertions need every header.
    let callback_uri = format!("/auth/callback?code=valid_code&state={state_token}");
    let response = router
        .clone()
        .oneshot(Request::builder().uri(&callback_uri).body(Body::empty()).unwrap())
        .await
        .unwrap();
    let status = response.status();
    assert!(
        status == StatusCode::SEE_OTHER || status == StatusCode::FOUND,
        "session cookie mode must redirect, got: {status}"
    );
    let redirect_location = response
        .headers()
        .get("location")
        .and_then(|v| v.to_str().ok())
        .expect("redirect response must have Location header");
    assert_eq!(
        redirect_location, "https://app.example.com/dashboard",
        "redirect must point to post_login_redirect_uri, not the caller's redirect_uri"
    );

    // A response may carry several `Set-Cookie` headers; `HeaderMap::get` only returns the
    // first one, so search all of them for the access-token cookie.
    let set_cookies: Vec<&str> = response
        .headers()
        .get_all("set-cookie")
        .iter()
        .filter_map(|v| v.to_str().ok())
        .collect();
    assert!(!set_cookies.is_empty(), "session cookie mode must set a Set-Cookie header");
    let set_cookie = set_cookies
        .iter()
        .copied()
        .find(|c| c.starts_with("__Host-access_token="))
        .unwrap_or_else(|| {
            panic!("cookie must use __Host-access_token prefix, got: {set_cookies:?}")
        });
    assert!(
        set_cookie.contains("test-access-token-xyz"),
        "cookie must contain the access token, got: {set_cookie}"
    );

    // Compare attributes as whole `;`-separated segments so that e.g. `Path=/foo` or
    // `Max-Age=36000` cannot satisfy a check by substring match.
    let attributes: Vec<&str> = set_cookie.split(';').skip(1).map(str::trim).collect();
    assert!(
        attributes.contains(&"HttpOnly"),
        "cookie must have HttpOnly attribute, got: {set_cookie}"
    );
    assert!(
        attributes.contains(&"Secure"),
        "cookie must have Secure attribute, got: {set_cookie}"
    );
    assert!(
        attributes.contains(&"SameSite=Strict"),
        "cookie must have SameSite=Strict, got: {set_cookie}"
    );
    assert!(attributes.contains(&"Path=/"), "cookie must have Path=/, got: {set_cookie}");
    assert!(
        attributes.contains(&"Max-Age=3600"),
        "cookie Max-Age must match token expires_in (3600), got: {set_cookie}"
    );
    // Browsers reject `__Host-` cookies that specify a Domain.
    assert!(
        !attributes.iter().any(|a| a.to_ascii_lowercase().starts_with("domain=")),
        "__Host- cookie must not set a Domain attribute, got: {set_cookie}"
    );
}
/// Runs the same start → callback → replay flow as the in-memory test, but backs the PKCE
/// state with Redis.
///
/// The Redis URL comes from `REDIS_TEST_URL` and falls back to `redis://127.0.0.1:6379`.
/// The HTTP client gets a bounded connect timeout because the token endpoint from
/// `test_oidc_client` is a TEST-NET-1 address that may drop packets. `reqwest::Client::new()`
/// has no timeout, so without one the token exchange could stall for the OS TCP connect
/// timeout.
#[cfg(feature = "redis-pkce")]
#[tokio::test]
#[ignore = "requires a running Redis (REDIS_TEST_URL, default redis://127.0.0.1:6379)"]
async fn auth_pkce_flow_with_redis_store() {
    let redis_url =
        std::env::var("REDIS_TEST_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
    let pkce_store = PkceStateStore::new_redis(&redis_url, 300, None)
        .await
        .expect("Redis PKCE store must connect");
    let http_client = reqwest::Client::builder()
        .connect_timeout(std::time::Duration::from_secs(5))
        .build()
        .expect("reqwest client with a connect timeout must build");
    let state = Arc::new(AuthPkceState {
        pkce_store: Arc::new(pkce_store),
        oidc_client: Arc::new(test_oidc_client()),
        http_client: Arc::new(http_client),
        post_login_redirect_uri: None,
    });
    let router = Router::new()
        .route("/auth/start", get(auth_start))
        .route("/auth/callback", get(auth_callback))
        .with_state(state);

    let (status, location) =
        get_request(&router, "/auth/start?redirect_uri=https://app.example.com/after-login").await;
    assert_eq!(status, StatusCode::SEE_OTHER, "auth_start must redirect (303)");
    let loc = location.expect("auth_start must provide Location header");
    let state_token = extract_state_param(&loc);
    assert!(!state_token.is_empty(), "state token must not be empty");

    let callback_uri = format!("/auth/callback?code=fake_code&state={state_token}");
    let (callback_status, _) = get_request(&router, &callback_uri).await;
    assert_ne!(
        callback_status,
        StatusCode::BAD_REQUEST,
        "state token must be found in Redis; got: {callback_status}"
    );

    let (replay_status, _) = get_request(&router, &callback_uri).await;
    assert_eq!(
        replay_status,
        StatusCode::BAD_REQUEST,
        "second use of the same state token must be rejected (state consumed)"
    );
}