use std::sync::Arc;
use axum::Router;
use axum::extract::{Form, Query, State};
use axum::http::{HeaderMap, StatusCode, header::USER_AGENT};
use axum::response::{IntoResponse, Redirect, Response};
use axum::routing::{get, post};
use axum_extra::extract::PrivateCookieJar;
use serde::Deserialize;
use super::config::PasAuthConfig;
use super::cookies;
use super::state::AuthState;
use super::traits::{AccountResolver, SessionStore};
use super::types::NewSession;
use crate::types::PpnumId;
// Route suffixes; `auth_routes` appends each one to `settings.auth_path`.
const ROUTE_LOGIN: &str = "/login";
const ROUTE_CALLBACK: &str = "/callback";
const ROUTE_LOGOUT: &str = "/logout";
// Only mounted when `settings.dev_login_enabled` is true.
const ROUTE_DEV_LOGIN: &str = "/dev-login";
// Headers consulted by `extract_client_ip`: X-Forwarded-For first, then X-Real-IP.
const HEADER_X_FORWARDED_FOR: &str = "x-forwarded-for";
const HEADER_X_REAL_IP: &str = "x-real-ip";
// ppnum used by `dev_login` when the request has no `ppnum` query parameter.
const DEFAULT_DEV_PPNUM: &str = "77700000001";
/// Builds the PAS authentication router.
///
/// All routes live under `settings.auth_path`:
/// - `GET {auth_path}/login` starts the OAuth2 + PKCE flow.
/// - `GET` and `POST {auth_path}/callback` complete it (query or form parameters).
/// - `POST {auth_path}/logout` deletes the session and clears its cookie.
/// - `GET {auth_path}/dev-login` is mounted only when `settings.dev_login_enabled` is true.
///
/// The shared [`AuthState`] is applied with `with_state`, so the returned
/// `Router` carries no outstanding state type.
pub fn auth_routes<U, S>(config: PasAuthConfig, account_resolver: U, session_store: S) -> Router
where
    U: AccountResolver,
    S: SessionStore,
{
    let prefix = config.settings.auth_path.clone();
    let dev_login_enabled = config.settings.dev_login_enabled;
    let at = |suffix: &str| format!("{prefix}{suffix}");

    let shared = AuthState {
        client: Arc::new(config.client),
        account_resolver: Arc::new(account_resolver),
        session_store: Arc::new(session_store),
        settings: config.settings,
    };

    let routes = Router::new()
        .route(&at(ROUTE_LOGIN), get(login::<U, S>))
        .route(&at(ROUTE_CALLBACK), get(callback_query::<U, S>).post(callback_form::<U, S>))
        .route(&at(ROUTE_LOGOUT), post(logout::<U, S>));

    let routes = if dev_login_enabled {
        routes.route(&at(ROUTE_DEV_LOGIN), get(dev_login::<U, S>))
    } else {
        routes
    };

    routes.with_state(shared)
}
/// Starts the PAS OAuth2 flow.
///
/// Asks the client for an authorization URL, stores its PKCE code verifier and
/// `state` value in private cookies scoped to `settings.auth_path`, and
/// redirects the browser to the authorization URL. The handler never returns `Err`.
async fn login<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
) -> Result<(PrivateCookieJar, Redirect), Response> {
    let request = state.client.authorization_url();
    let settings = &state.settings;

    let (verifier_cookie, csrf_cookie) = cookies::pkce_cookies(
        &request.code_verifier,
        &request.state,
        settings.secure_cookies,
        &settings.auth_path,
    );

    let jar = jar.add(verifier_cookie).add(csrf_cookie);
    let redirect = Redirect::to(&request.url);
    Ok((jar, redirect))
}
/// Parameters PAS sends to the callback endpoint, either in the query string
/// (`GET`) or as a form body (`POST`).
///
/// Every field is optional. `process_callback` inspects each one and turns a
/// missing value into an error redirect.
#[derive(Deserialize)]
struct CallbackParams {
    /// OAuth2 error code; when present, the callback is treated as failed.
    error: Option<String>,
    /// Free-text description accompanying `error`.
    error_description: Option<String>,
    /// Authorization code to exchange for tokens.
    code: Option<String>,
    /// Value that must equal the `state` stored in a cookie during `login`.
    state: Option<String>,
}
/// `GET` callback: PAS delivered the result as query-string parameters.
///
/// Delegates to `process_callback`.
async fn callback_query<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
    Query(params): Query<CallbackParams>,
    headers: HeaderMap,
) -> Result<(PrivateCookieJar, Redirect), Response> {
    let outcome = process_callback(state, jar, params, headers);
    outcome.await
}
/// `POST` callback: PAS delivered the result as a form-encoded body.
///
/// `Form` stays the last extractor because axum requires body-consuming
/// extractors to come last. Delegates to `process_callback`.
async fn callback_form<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
    headers: HeaderMap,
    Form(params): Form<CallbackParams>,
) -> Result<(PrivateCookieJar, Redirect), Response> {
    let outcome = process_callback(state, jar, params, headers);
    outcome.await
}
async fn process_callback<U: AccountResolver, S: SessionStore>(
state: AuthState<U, S>,
jar: PrivateCookieJar,
params: CallbackParams,
headers: HeaderMap,
) -> Result<(PrivateCookieJar, Redirect), Response> {
if let Some(error) = ¶ms.error {
let desc = params.error_description.as_deref().unwrap_or("Unknown error");
tracing::warn!(error = %error, description = %desc, "OAuth2 error from PAS");
return Err(error_redirect(&state.settings.error_redirect, desc));
}
let code = params
.code
.ok_or_else(|| error_redirect(&state.settings.error_redirect, "missing_code"))?;
let received_state = params
.state
.ok_or_else(|| error_redirect(&state.settings.error_redirect, "state_mismatch"))?;
let stored_state = cookies::get_state(&jar)
.ok_or_else(|| error_redirect(&state.settings.error_redirect, "state_mismatch"))?;
if received_state != stored_state {
tracing::warn!("OAuth state mismatch");
return Err(error_redirect(&state.settings.error_redirect, "state_mismatch"));
}
let code_verifier = cookies::get_pkce_verifier(&jar)
.ok_or_else(|| error_redirect(&state.settings.error_redirect, "missing_verifier"))?;
let token_response = state
.client
.exchange_code(&code, &code_verifier)
.await
.map_err(|e| {
tracing::error!(error = %e, "Token exchange failed");
error_redirect(&state.settings.error_redirect, "token_exchange_failed")
})?;
let user_info = state
.client
.get_user_info(&token_response.access_token)
.await
.map_err(|e| {
tracing::error!(error = %e, "Userinfo request failed");
error_redirect(&state.settings.error_redirect, "userinfo_failed")
})?;
let ppnum_id = user_info.sub;
let user_id = state
.account_resolver
.resolve(&ppnum_id, &user_info.ppnum)
.await
.map_err(|e| {
tracing::error!(error = %e, "Account resolution failed");
error_redirect(&state.settings.error_redirect, "account_resolution_failed")
})?;
let session = NewSession {
ppnum_id,
user_id,
refresh_token: token_response.refresh_token,
user_agent: extract_user_agent(&headers),
ip_address: extract_client_ip(&headers),
user_info,
};
let session_id = state
.session_store
.create(session)
.await
.map_err(|e| {
tracing::error!(error = %e, "Session creation failed");
error_redirect(&state.settings.error_redirect, "session_failed")
})?;
let session_cookie = cookies::session_cookie(
&state.settings.session_cookie_name,
&session_id.to_string(),
state.settings.session_ttl_days,
state.settings.secure_cookies,
);
let (clear_pkce, clear_state) = cookies::clear_pkce_cookies(&state.settings.auth_path);
let jar = jar
.add(session_cookie)
.add(clear_pkce)
.add(clear_state);
tracing::info!(session_id = %session_id, "PAS OAuth2 login successful");
Ok((jar, Redirect::to(&state.settings.login_redirect)))
}
/// Ends the current session.
///
/// If the request carries the session cookie, the matching session is deleted
/// from the store. A deletion failure is logged but does not block the logout.
/// The session cookie is always cleared, and the browser is redirected to
/// `settings.logout_redirect`.
async fn logout<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
) -> (PrivateCookieJar, Redirect) {
    let current = jar
        .get(&state.settings.session_cookie_name)
        .map(|cookie| crate::types::SessionId(cookie.value().to_string()));

    if let Some(session_id) = current {
        let deletion = state.session_store.delete(&session_id).await;
        if let Err(err) = deletion {
            tracing::error!(
                error = %err,
                session_id = %session_id,
                "Session deletion failed during logout — session may persist in store"
            );
        }
    }

    let expired = cookies::clear_session_cookie(&state.settings.session_cookie_name);
    let jar = jar.remove(expired);
    (jar, Redirect::to(&state.settings.logout_redirect))
}
/// Query parameters accepted by `dev_login`.
#[derive(Deserialize)]
struct DevLoginParams {
    /// ppnum to log in as; `DEFAULT_DEV_PPNUM` is used when omitted.
    ppnum: Option<String>,
}
/// Development shortcut that creates a session without contacting PAS.
///
/// `auth_routes` mounts it only when `settings.dev_login_enabled` is true.
/// The ppnum comes from `?ppnum=` (default `DEFAULT_DEV_PPNUM`). The
/// `PpnumId` is made by left-padding that ppnum with zeros to 26 characters,
/// and the synthetic user gets a verified `{ppnum}@dev.local` email.
///
/// # Errors
///
/// - `400 Bad Request` if the ppnum or the padded id fails to parse.
/// - `500 Internal Server Error` if account resolution or session creation fails.
async fn dev_login<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
    Query(params): Query<DevLoginParams>,
    headers: HeaderMap,
) -> Result<(PrivateCookieJar, Redirect), Response> {
    let bad_ppnum = || (StatusCode::BAD_REQUEST, "Invalid ppnum for dev login").into_response();
    let raw = params.ppnum.as_deref().unwrap_or(DEFAULT_DEV_PPNUM);

    let ppnum: crate::types::Ppnum = raw.parse().map_err(|_| bad_ppnum())?;
    let padded = format!("{raw:0>26}");
    let ppnum_id: PpnumId = padded.parse().map_err(|_| bad_ppnum())?;

    let email = format!("{raw}@dev.local");
    let user_info = crate::oauth::UserInfo::new(ppnum_id, ppnum)
        .with_email(email)
        .with_email_verified(true);

    let resolved = state
        .account_resolver
        .resolve(&ppnum_id, &user_info.ppnum)
        .await;
    let user_id = match resolved {
        Ok(id) => id,
        Err(e) => {
            tracing::error!(error = %e, "Dev account resolution failed");
            return Err((StatusCode::INTERNAL_SERVER_ERROR, "Dev login failed").into_response());
        }
    };

    let session = NewSession {
        ppnum_id,
        user_id,
        refresh_token: None,
        user_agent: extract_user_agent(&headers),
        ip_address: extract_client_ip(&headers),
        user_info,
    };
    let created = state.session_store.create(session).await;
    let session_id = match created {
        Ok(id) => id,
        Err(e) => {
            tracing::error!(error = %e, "Dev session creation failed");
            return Err((StatusCode::INTERNAL_SERVER_ERROR, "Dev login failed").into_response());
        }
    };

    let cookie = cookies::session_cookie(
        &state.settings.session_cookie_name,
        &session_id.to_string(),
        state.settings.session_ttl_days,
        state.settings.secure_cookies,
    );
    tracing::info!(session_id = %session_id, "Dev login successful");
    let jar = jar.add(cookie);
    Ok((jar, Redirect::to(&state.settings.login_redirect)))
}
/// Builds a redirect to the configured error page with `code` as the
/// URL-encoded `error` query parameter.
///
/// `error_redirect` may already contain a query string and/or a fragment
/// (e.g. `/login?from=auth#top`).
/// - If it has a query, the parameter is joined with `&`, or with nothing when
///   the URL already ends in `?` or `&`.
/// - It is always placed before any `#fragment`, so the URL stays well-formed.
///
/// A plain path such as `/error` yields `/error?error=<code>`.
fn error_redirect(error_redirect: &str, code: &str) -> Response {
    let encoded = urlencoding::encode(code);

    // Query parameters must precede the fragment, so split it off first.
    let (base, fragment) = match error_redirect.split_once('#') {
        Some((base, fragment)) => (base, Some(fragment)),
        None => (error_redirect, None),
    };

    let separator = if !base.contains('?') {
        "?"
    } else if base.ends_with('?') || base.ends_with('&') {
        ""
    } else {
        "&"
    };

    let mut target = format!("{base}{separator}error={encoded}");
    if let Some(fragment) = fragment {
        target.push('#');
        target.push_str(fragment);
    }
    Redirect::to(&target).into_response()
}
/// Returns the request's `User-Agent` header as an owned string.
///
/// Yields `None` when the header is missing or `HeaderValue::to_str` rejects it.
fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get(USER_AGENT)?;
    let text = raw.to_str().ok()?;
    Some(text.to_owned())
}
/// Best-effort client IP address, recorded as session metadata.
///
/// Prefers the right-most non-empty entry of `X-Forwarded-For`, which is the
/// hop appended by the proxy closest to this server. Falls back to a non-empty,
/// trimmed `X-Real-IP`. Returns `None` when neither yields a usable value.
///
/// Both headers can be set by the client unless a trusted reverse proxy
/// overwrites them, so the result must not be used for authorization.
/// The value is not validated as an IP address.
fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
    // Repeated X-Forwarded-For lines form one ordered, comma-separated list,
    // so the right-most hop is in the *last* header line, not the first.
    let forwarded = headers
        .get_all(HEADER_X_FORWARDED_FOR)
        .iter()
        .last()
        .and_then(|v| v.to_str().ok())
        .and_then(|list| list.rsplit(',').map(str::trim).find(|hop| !hop.is_empty()));

    forwarded
        .or_else(|| {
            headers
                .get(HEADER_X_REAL_IP)
                .and_then(|v| v.to_str().ok())
                .map(str::trim)
                .filter(|ip| !ip.is_empty())
        })
        .map(str::to_owned)
}