use axum::Router;
use axum::extract::{Form, Query, State};
use axum::http::{HeaderMap, StatusCode, header::USER_AGENT};
use axum::response::{IntoResponse, Redirect, Response};
use axum::routing::{get, post};
use axum_extra::extract::PrivateCookieJar;
use serde::Deserialize;
use super::cookies;
use super::state::AuthState;
use super::traits::{AccountResolver, SessionStore};
use super::types::NewSession;
use crate::types::PpnumId;
// Route suffixes; `build_router` appends each one to the configured `auth_path`.
/// Starts the OAuth2 flow (GET).
const ROUTE_LOGIN: &str = "/login";
/// Receives the provider's response (GET with query params or POST with a form body).
const ROUTE_CALLBACK: &str = "/callback";
/// Ends the current session (POST).
const ROUTE_LOGOUT: &str = "/logout";
/// Development-only login; only registered when `dev_login_enabled` is set.
const ROUTE_DEV_LOGIN: &str = "/dev-login";
/// Header consulted first by `extract_client_ip`.
const HEADER_X_FORWARDED_FOR: &str = "x-forwarded-for";
/// Fallback header used by `extract_client_ip` when no usable X-Forwarded-For entry exists.
const HEADER_X_REAL_IP: &str = "x-real-ip";
/// Identity used by `dev_login` when the request carries no `ppnum` query parameter.
const DEFAULT_DEV_PPNUM: &str = "77700000001";
/// Assembles the authentication router.
///
/// Every route is mounted under `state.settings.auth_path`:
/// - `GET  {auth_path}/login`
/// - `GET|POST {auth_path}/callback`
/// - `POST {auth_path}/logout`
/// - `GET  {auth_path}/dev-login` (only when `dev_login_enabled` is true)
///
/// The supplied `state` becomes the router state shared by all handlers.
pub(super) fn build_router<U, S>(state: AuthState<U, S>) -> Router
where
    U: AccountResolver,
    S: SessionStore,
{
    // Owned copy of the prefix so `state` can be moved into the router below.
    let prefix = state.settings.auth_path.clone();
    let path_for = |suffix: &str| format!("{prefix}{suffix}");

    let core = Router::new()
        .route(&path_for(ROUTE_LOGIN), get(login::<U, S>))
        .route(
            &path_for(ROUTE_CALLBACK),
            get(callback_query::<U, S>).post(callback_form::<U, S>),
        )
        .route(&path_for(ROUTE_LOGOUT), post(logout::<U, S>));

    let routes = if state.settings.dev_login_enabled {
        core.route(&path_for(ROUTE_DEV_LOGIN), get(dev_login::<U, S>))
    } else {
        core
    };

    routes.with_state(state)
}
/// `GET {auth_path}/login` — begins the authorization flow.
///
/// Asks the OAuth client for an authorization request, stores its PKCE code
/// verifier and `state` value in cookies (built by `cookies::pkce_cookies`),
/// and redirects the browser to the authorization URL.
///
/// The `Err` arm of the return type is never produced by this body.
async fn login<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
) -> Result<(PrivateCookieJar, Redirect), Response> {
    let settings = &state.settings;
    let request = state.client.authorization_url();

    let (verifier_cookie, csrf_cookie) = cookies::pkce_cookies(
        &request.code_verifier,
        &request.state,
        settings.secure_cookies,
        &settings.auth_path,
    );

    let updated_jar = jar.add(verifier_cookie).add(csrf_cookie);
    let redirect = Redirect::to(&request.url);
    Ok((updated_jar, redirect))
}
/// Parameters delivered to the callback route, either as a query string
/// (`callback_query`) or as a form body (`callback_form`).
///
/// Every field is optional; `process_callback` decides which combinations are
/// acceptable.
#[derive(Deserialize)]
struct CallbackParams {
/// Authorization code that `process_callback` exchanges for tokens.
code: Option<String>,
/// Value compared against the `state` stored in the cookie jar.
state: Option<String>,
/// Error identifier reported by the provider; when present the callback fails early.
error: Option<String>,
/// Optional description accompanying `error`.
error_description: Option<String>,
}
/// `GET {auth_path}/callback` — callback whose parameters arrive in the query string.
///
/// Thin adapter: extracts the parameters and delegates to `process_callback`.
async fn callback_query<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
    Query(params): Query<CallbackParams>,
    headers: HeaderMap,
) -> (PrivateCookieJar, Response) {
    // Read the hop count before `state` is moved into the shared handler.
    let trusted_hops = state.settings.xff_trusted_proxies;
    process_callback(state, jar, params, headers, trusted_hops).await
}
/// `POST {auth_path}/callback` — callback whose parameters arrive as a form body.
///
/// Thin adapter: extracts the parameters and delegates to `process_callback`.
/// `Form` is the last extractor because it consumes the request body.
async fn callback_form<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
    headers: HeaderMap,
    Form(params): Form<CallbackParams>,
) -> (PrivateCookieJar, Response) {
    // Read the hop count before `state` is moved into the shared handler.
    let trusted_hops = state.settings.xff_trusted_proxies;
    process_callback(state, jar, params, headers, trusted_hops).await
}
/// Shared callback logic behind `callback_query` and `callback_form`.
///
/// Steps, in order:
/// 1. If the provider reported an `error`, log it and fail.
/// 2. Require `code` and `state`; `state` must equal the value stored in the
///    cookie jar.
/// 3. Read the PKCE verifier from the jar and exchange the code for tokens.
/// 4. Fetch user info and resolve it to a local account via the
///    `AccountResolver`.
/// 5. If a refresh-token cipher is configured and a refresh token was
///    returned, encrypt it; otherwise no refresh token is stored.
/// 6. Create a session, set the session cookie, clear the PKCE/state cookies
///    and redirect to `login_redirect`.
///
/// Every failure clears the PKCE/state cookies and redirects to
/// `settings.error_redirect` with an `error` query parameter.
///
/// `xff_skip` is forwarded to `extract_client_ip` when recording the client
/// address on the session.
async fn process_callback<U: AccountResolver, S: SessionStore>(
    state: AuthState<U, S>,
    jar: PrivateCookieJar,
    params: CallbackParams,
    headers: HeaderMap,
    xff_skip: usize,
) -> (PrivateCookieJar, Response) {
    // Common failure exit: drop the one-shot PKCE/state cookies and redirect.
    let fail = |jar: PrivateCookieJar, code: &str| -> (PrivateCookieJar, Response) {
        let (clear_pkce, clear_state) = cookies::clear_pkce_cookies(&state.settings.auth_path);
        let jar = jar.add(clear_pkce).add(clear_state);
        (jar, error_redirect(&state.settings.error_redirect, code))
    };

    // Provider-side error: nothing to exchange.
    if let Some(error) = &params.error {
        let desc = params
            .error_description
            .as_deref()
            .unwrap_or("Unknown error");
        tracing::warn!(error = %error, description = %desc, "OAuth2 error from PAS");
        // NOTE(review): the redirect carries the description, not the error
        // identifier, unlike the fixed codes used on every other path —
        // confirm this is what the error page expects.
        return fail(jar, desc);
    }

    let Some(code) = params.code else {
        return fail(jar, "missing_code");
    };

    // A missing `state` on either side is reported the same way as a mismatch.
    let Some(received_state) = params.state else {
        return fail(jar, "state_mismatch");
    };
    let Some(stored_state) = cookies::get_state(&jar) else {
        return fail(jar, "state_mismatch");
    };
    if received_state != stored_state {
        tracing::warn!("OAuth state mismatch");
        return fail(jar, "state_mismatch");
    }

    let Some(code_verifier) = cookies::get_pkce_verifier(&jar) else {
        return fail(jar, "missing_verifier");
    };

    let token_response = match state.client.exchange_code(&code, &code_verifier).await {
        Ok(t) => t,
        Err(e) => {
            tracing::error!(error = %e, "Token exchange failed");
            return fail(jar, "token_exchange_failed");
        }
    };

    let user_info = match state
        .client
        .get_user_info(&token_response.access_token)
        .await
    {
        Ok(u) => u,
        Err(e) => {
            tracing::error!(error = %e, "Userinfo request failed");
            return fail(jar, "userinfo_failed");
        }
    };

    let ppnum_id = user_info.sub;
    let user_id = match state
        .account_resolver
        .resolve(&ppnum_id, &user_info.ppnum)
        .await
    {
        Ok(id) => id,
        Err(e) => {
            tracing::error!(error = %e, "Account resolution failed");
            return fail(jar, "account_resolution_failed");
        }
    };

    // The refresh token is stored only when both a cipher and a token exist;
    // an encryption failure aborts the login.
    let encrypted_refresh = match (
        state.settings.refresh_token_cipher.as_ref(),
        token_response.refresh_token.as_deref(),
    ) {
        (Some(cipher), Some(rt)) => match cipher.encrypt_to_token(rt) {
            Ok(ct) => Some(ct),
            Err(e) => {
                tracing::error!(error = %e, "refresh_token encryption failed");
                return fail(jar, "refresh_token_encryption_failed");
            }
        },
        _ => None,
    };

    let session = NewSession {
        ppnum_id,
        user_id,
        refresh_token: encrypted_refresh,
        user_agent: extract_user_agent(&headers),
        ip_address: extract_client_ip(&headers, xff_skip),
        user_info,
    };
    let session_id = match state.session_store.create(session).await {
        Ok(id) => id,
        Err(e) => {
            tracing::error!(error = %e, "Session creation failed");
            return fail(jar, "session_failed");
        }
    };

    let session_cookie = cookies::session_cookie(
        &state.settings.session_cookie_name,
        &session_id.to_string(),
        state.settings.session_ttl_days,
        state.settings.secure_cookies,
    );
    // Success also clears the one-shot PKCE/state cookies.
    let (clear_pkce, clear_state) = cookies::clear_pkce_cookies(&state.settings.auth_path);
    let jar = jar.add(session_cookie).add(clear_pkce).add(clear_state);

    tracing::info!(session_id = %session_id, "PAS OAuth2 login successful");
    (
        jar,
        Redirect::to(&state.settings.login_redirect).into_response(),
    )
}
/// `POST {auth_path}/logout` — ends the current session.
///
/// If the jar holds a session cookie, the matching session is deleted from
/// the store. A store failure is logged but does not stop the logout: the
/// cookie is removed and the browser is redirected to `logout_redirect`
/// regardless.
async fn logout<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
) -> (PrivateCookieJar, Redirect) {
    if let Some(cookie) = jar.get(&state.settings.session_cookie_name) {
        let session_id = crate::types::SessionId(cookie.value().to_owned());
        match state.session_store.delete(&session_id).await {
            Ok(_) => {}
            Err(e) => {
                tracing::error!(
                    error = %e,
                    session_id = %session_id,
                    "Session deletion failed during logout — session may persist in store"
                );
            }
        }
    }

    let removal = cookies::clear_session_cookie(&state.settings.session_cookie_name);
    let cleared_jar = jar.remove(removal);
    let redirect = Redirect::to(&state.settings.logout_redirect);
    (cleared_jar, redirect)
}
/// Query parameters accepted by the `dev_login` route.
#[derive(Deserialize)]
struct DevLoginParams {
/// Identity to log in as; `dev_login` falls back to `DEFAULT_DEV_PPNUM` when absent.
ppnum: Option<String>,
}
/// `GET {auth_path}/dev-login` — creates a session without contacting the
/// OAuth provider. Only routed when `dev_login_enabled` is set.
///
/// The identity comes from the `ppnum` query parameter (default
/// `DEFAULT_DEV_PPNUM`). The `PpnumId` is derived by left-padding the raw
/// value with `0` to 26 characters, and the e-mail is `{ppnum}@dev.local`,
/// marked as verified. No refresh token is stored.
///
/// # Errors
/// - `400 Bad Request` when the value does not parse as a `Ppnum` or the
///   padded value does not parse as a `PpnumId`.
/// - `500 Internal Server Error` when account resolution or session creation
///   fails (details are logged).
async fn dev_login<U: AccountResolver, S: SessionStore>(
    State(state): State<AuthState<U, S>>,
    jar: PrivateCookieJar,
    Query(params): Query<DevLoginParams>,
    headers: HeaderMap,
) -> Result<(PrivateCookieJar, Redirect), Response> {
    fn invalid_ppnum() -> Response {
        (StatusCode::BAD_REQUEST, "Invalid ppnum for dev login").into_response()
    }
    fn login_failed() -> Response {
        (StatusCode::INTERNAL_SERVER_ERROR, "Dev login failed").into_response()
    }

    let raw = params.ppnum.as_deref().unwrap_or(DEFAULT_DEV_PPNUM);
    let Ok(ppnum) = raw.parse::<crate::types::Ppnum>() else {
        return Err(invalid_ppnum());
    };
    let padded = format!("{raw:0>26}");
    let Ok(ppnum_id) = padded.parse::<PpnumId>() else {
        return Err(invalid_ppnum());
    };

    let user_info = crate::oauth::UserInfo::new(ppnum_id, ppnum)
        .with_email(format!("{raw}@dev.local"))
        .with_email_verified(true);

    let user_id = match state
        .account_resolver
        .resolve(&ppnum_id, &user_info.ppnum)
        .await
    {
        Ok(id) => id,
        Err(e) => {
            tracing::error!(error = %e, "Dev account resolution failed");
            return Err(login_failed());
        }
    };

    let user_agent = extract_user_agent(&headers);
    let ip_address = extract_client_ip(&headers, state.settings.xff_trusted_proxies);
    let session = NewSession {
        ppnum_id,
        user_id,
        refresh_token: None,
        user_agent,
        ip_address,
        user_info,
    };

    let session_id = match state.session_store.create(session).await {
        Ok(id) => id,
        Err(e) => {
            tracing::error!(error = %e, "Dev session creation failed");
            return Err(login_failed());
        }
    };

    let session_cookie = cookies::session_cookie(
        &state.settings.session_cookie_name,
        &session_id.to_string(),
        state.settings.session_ttl_days,
        state.settings.secure_cookies,
    );
    tracing::info!(session_id = %session_id, "Dev login successful");

    let redirect = Redirect::to(&state.settings.login_redirect);
    Ok((jar.add(session_cookie), redirect))
}
fn error_redirect(error_redirect: &str, code: &str) -> Response {
let encoded = urlencoding::encode(code);
Redirect::to(&format!("{error_redirect}?error={encoded}")).into_response()
}
/// Returns the request's `User-Agent` header as an owned string.
///
/// Yields `None` when the header is absent or its value cannot be read as a
/// string by `HeaderValue::to_str`.
fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get(USER_AGENT)?;
    let text = raw.to_str().ok()?;
    Some(text.to_owned())
}
/// Determines the client address from proxy headers.
///
/// `X-Forwarded-For` is consulted first. Its comma-separated entries are
/// trimmed, empty entries are discarded, and the entry `xff_skip` positions
/// from the right is returned (`0` = rightmost). When `xff_skip` reaches or
/// exceeds the number of entries, the leftmost entry is returned.
///
/// If the header is missing, unreadable, or has no non-empty entries, the
/// trimmed `X-Real-IP` value is used instead. Returns `None` when neither
/// header yields a non-empty value.
///
/// Only the first `X-Forwarded-For` header line is inspected, because
/// `HeaderMap::get` returns the first value for a name.
fn extract_client_ip(headers: &HeaderMap, xff_skip: usize) -> Option<String> {
    if let Some(xff) = headers
        .get(HEADER_X_FORWARDED_FOR)
        .and_then(|v| v.to_str().ok())
    {
        let entries: Vec<&str> = xff
            .split(',')
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .collect();
        // `saturating_add` keeps `xff_skip == usize::MAX` from overflowing;
        // `saturating_sub` clamps an over-large skip to the leftmost entry.
        let idx = entries.len().saturating_sub(xff_skip.saturating_add(1));
        // `get` returns `None` for an empty list, which falls through to X-Real-IP.
        if let Some(entry) = entries.get(idx) {
            return Some((*entry).to_string());
        }
    }

    headers
        .get(HEADER_X_REAL_IP)
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        // A blank header carries no address; report it as absent.
        .filter(|s| !s.is_empty())
        .map(str::to_owned)
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Three-entry chain shared by the skip-count tests.
    const CHAIN: &str = "203.0.113.1, 198.51.100.2, 10.0.0.1";

    /// Header map holding a single `X-Forwarded-For` value.
    fn xff(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(HEADER_X_FORWARDED_FOR, HeaderValue::from_str(value).unwrap());
        headers
    }

    /// Runs `extract_client_ip` against an `X-Forwarded-For` value.
    fn pick(value: &str, skip: usize) -> Option<String> {
        extract_client_ip(&xff(value), skip)
    }

    #[test]
    fn xff_skip_0_returns_rightmost() {
        assert_eq!(pick(CHAIN, 0), Some("10.0.0.1".to_owned()));
    }

    #[test]
    fn xff_skip_1_returns_second_from_right() {
        assert_eq!(pick(CHAIN, 1), Some("198.51.100.2".to_owned()));
    }

    #[test]
    fn xff_skip_2_returns_third_from_right() {
        assert_eq!(pick(CHAIN, 2), Some("203.0.113.1".to_owned()));
    }

    #[test]
    fn xff_skip_clamped_to_leftmost() {
        assert_eq!(pick(CHAIN, 99), Some("203.0.113.1".to_owned()));
    }

    #[test]
    fn x_real_ip_fallback_when_no_xff() {
        let mut headers = HeaderMap::new();
        headers.insert(HEADER_X_REAL_IP, HeaderValue::from_static("203.0.113.7"));
        assert_eq!(
            extract_client_ip(&headers, 0),
            Some("203.0.113.7".to_owned())
        );
    }

    #[test]
    fn empty_headers_yield_none() {
        let headers = HeaderMap::new();
        assert_eq!(extract_client_ip(&headers, 0), None);
    }

    #[test]
    fn whitespace_around_entries_is_trimmed() {
        let padded = " 203.0.113.1 , 10.0.0.1 ";
        assert_eq!(pick(padded, 0), Some("10.0.0.1".to_owned()));
        assert_eq!(pick(padded, 1), Some("203.0.113.1".to_owned()));
    }
}