use crate::session::id::SessionId;
use crate::session::store::SessionRevoker;
use axess_clock::Clock;
use axess_factors::oauth::OAuthProvider;
use axess_factors::oidc::logout_token::{
IatCheck, MAX_IAT_AGE_SECS, aud_contains, azp_satisfied, check_iat, decode_jwt_payload,
events_contains_logout,
};
use axum::extract::{Form, State};
use axum::http::StatusCode;
use dashmap::DashMap;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
/// Shared state for the OIDC back-channel logout endpoint.
///
/// Every field is an `Arc` (or an alias for one), so cloning the handler only
/// bumps reference counts and all clones observe the same maps.
#[derive(Clone)]
pub struct BackChannelLogoutHandler {
/// Provider entries keyed by issuer string; built once in `new`.
providers_by_issuer: Arc<HashMap<String, ProviderEntry>>,
/// Revoker used to invalidate all of a user's sessions or one specific session.
registry: Arc<dyn SessionRevoker>,
/// Shared `(issuer, sid)` map; the matching entry is removed when a logout
/// token carrying that `sid` is processed.
sid_map: SidMap,
/// `(issuer, jti)` pairs that were already accepted, with the time each one
/// was recorded. Used to reject replayed logout tokens.
seen_jtis: Arc<DashMap<(String, String), chrono::DateTime<chrono::Utc>>>,
/// Time source for the `iat` check and for `jti` bookkeeping.
clock: Arc<dyn Clock>,
}
/// Entry count at which `record_jti` evicts old entries from the `jti` cache
/// before inserting a new one.
const JTI_CACHE_MAX: usize = 16 * 1024;
/// Upper bound on how many of the oldest entries `record_jti` evicts in one
/// pass once the cache has reached `JTI_CACHE_MAX`.
const JTI_EVICT_BATCH: usize = 128;
/// Key of the [`SidMap`]: an `(issuer, sid)` pair. The logout handler builds
/// it from a token's `iss` and `sid` claims.
pub type SidKey = (String, String);
/// Shared concurrent map from [`SidKey`] to `(user id, session id, timestamp)`.
///
/// The logout handler in this file only removes entries from the map.
/// NOTE(review): the timestamp looks like the insertion time (the handler
/// binds it as `_inserted_at` and ignores it) — confirm against the code that
/// populates the map.
pub type SidMap = Arc<
DashMap<
SidKey,
(
crate::authn::ids::UserId,
SessionId,
chrono::DateTime<chrono::Utc>,
),
>,
>;
/// Per-issuer data captured from an `OAuthProvider` when the handler is built.
#[derive(Clone)]
struct ProviderEntry {
/// Client id that the logout token's `aud` (and, where required, `azp`) is
/// checked against.
client_id: String,
/// Provider name; used only as a field in log output.
name: Arc<str>,
/// Provider used to verify the logout JWT and to refresh its JWKS.
provider: Arc<dyn OAuthProvider>,
}
/// Form body accepted by the back-channel logout endpoint.
#[derive(Deserialize)]
pub struct LogoutParams {
/// The raw logout token (a JWT) as sent in the request form.
pub logout_token: String,
}
/// Claims pulled out of a logout token.
///
/// When produced by `validate_logout_token`, at least one of `sub` and `sid`
/// is `Some`.
#[derive(Debug)]
pub struct LogoutTokenClaims {
/// `sub` claim, if present as a string.
pub sub: Option<String>,
/// `sid` claim, if present as a string.
pub sid: Option<String>,
/// `iss` claim; also the key used to select the provider.
pub iss: String,
/// `jti` claim, if present as a string; used for replay detection.
pub jti: Option<String>,
}
impl BackChannelLogoutHandler {
pub fn new(
providers: &[Arc<dyn OAuthProvider>],
registry: Arc<dyn SessionRevoker>,
sid_map: SidMap,
clock: Arc<dyn Clock>,
) -> Option<Self> {
let mut by_issuer = HashMap::new();
for provider in providers {
if let (Some(issuer), Some(client_id)) = (provider.issuer(), provider.client_id()) {
by_issuer.insert(
issuer.to_string(),
ProviderEntry {
client_id: client_id.to_string(),
name: provider.name().clone(),
provider: provider.clone(),
},
);
}
}
if by_issuer.is_empty() {
return None;
}
Some(Self {
providers_by_issuer: Arc::new(by_issuer),
registry,
sid_map,
seen_jtis: Arc::new(DashMap::new()),
clock,
})
}
/// Axum handler for the back-channel logout endpoint.
///
/// Flow:
/// 1. Validate the form-encoded `logout_token` via `validate_logout_token`.
/// 2. If the token carries a `jti`, record it and reject the request when the
///    same `(iss, jti)` pair was already seen.
/// 3. If `sub` is present, parse it as a `UserId` and invalidate that user's
///    sessions through the registry.
/// 4. If `sid` is present, remove the `(iss, sid)` entry from the sid map and
///    invalidate the session it pointed to. An unknown `sid` is not an error;
///    it is only logged when there was no `sub` to act on either.
///
/// # Errors
/// Returns `400 Bad Request` when token validation fails, when the `jti` is a
/// replay, or when `sub` is not a valid `UserId`.
pub async fn handle_backchannel_logout(
    State(handler): State<BackChannelLogoutHandler>,
    Form(params): Form<LogoutParams>,
) -> Result<StatusCode, StatusCode> {
    let claims = handler.validate_logout_token(&params.logout_token).await?;

    // Replay protection: a `jti` may be accepted only once per issuer.
    if let Some(ref jti) = claims.jti
        && !handler.record_jti(&claims.iss, jti, handler.clock.now())
    {
        tracing::warn!(
            iss = %claims.iss,
            jti = %jti,
            "back-channel logout: jti replay rejected"
        );
        return Err(StatusCode::BAD_REQUEST);
    }

    let provider_name = handler
        .providers_by_issuer
        .get(&claims.iss)
        .map(|p| p.name.as_ref())
        .unwrap_or("unknown");
    tracing::info!(
        provider = %provider_name,
        iss = %claims.iss,
        sub = ?claims.sub,
        sid = ?claims.sid,
        "back-channel logout: invalidating session(s)"
    );

    // `sub` path: invalidate through the registry by user id.
    if let Some(ref sub) = claims.sub {
        match crate::authn::ids::UserId::try_new(sub.as_str()) {
            Ok(uid) => handler.registry.invalidate_user(&uid).await,
            Err(e) => {
                tracing::warn!(
                    iss = %claims.iss,
                    sub = %sub,
                    error = %e,
                    "back-channel logout: provider sub is not a valid UserId; rejecting"
                );
                return Err(StatusCode::BAD_REQUEST);
            }
        }
    }

    // `sid` path: the map entry is consumed so it cannot be matched again.
    if let Some(ref sid) = claims.sid {
        let key: SidKey = (claims.iss.clone(), sid.clone());
        if let Some((_, (user_id, session_id, _inserted_at))) = handler.sid_map.remove(&key) {
            tracing::info!(
                iss = %claims.iss,
                oidc_sid = %sid,
                user_id = %user_id,
                "back-channel logout: invalidating session by OIDC sid"
            );
            handler
                .registry
                .invalidate_session(&user_id, &session_id)
                .await;
        } else if claims.sub.is_none() {
            tracing::warn!(
                iss = %claims.iss,
                sid = %sid,
                "back-channel logout: sid not found in sid map and no sub to fall back to"
            );
        }
    }

    Ok(StatusCode::OK)
}
/// Validates a logout token and extracts the claims the handler acts on.
///
/// Checks, in order:
/// 1. Decode the payload without verification to read `iss` and pick the
///    matching provider entry.
/// 2. Verify the JWT with that provider. On an unknown `kid` the JWKS is
///    refreshed once (a refresh failure is only logged) and verification is
///    retried.
/// 3. `aud` must contain the provider's client id and `azp` must be satisfied.
/// 4. `iat` must pass `check_iat` against the handler's clock.
/// 5. `events` must contain the back-channel logout member.
/// 6. At least one of `sub` / `sid` must be present, and `nonce` must be absent.
///
/// # Errors
/// Every failure is logged and reported as `StatusCode::BAD_REQUEST`.
async fn validate_logout_token(&self, token: &str) -> Result<LogoutTokenClaims, StatusCode> {
    let unverified = match decode_jwt_payload(token) {
        Ok(decoded) => decoded,
        Err(e) => {
            tracing::warn!(error = %e, "back-channel logout: failed to decode JWT");
            return Err(StatusCode::BAD_REQUEST);
        }
    };

    // The unverified issuer is used only to choose which provider verifies the token.
    let Some(iss) = unverified.get("iss").and_then(|v| v.as_str()) else {
        tracing::warn!("back-channel logout: missing iss claim");
        return Err(StatusCode::BAD_REQUEST);
    };
    let Some(entry) = self.providers_by_issuer.get(iss) else {
        tracing::warn!(iss = %iss, "back-channel logout: unknown issuer");
        return Err(StatusCode::BAD_REQUEST);
    };

    let payload = match entry.provider.verify_logout_jwt(token) {
        Ok(verified) => verified,
        Err(axess_factors::oauth::OAuthError::UnknownKid(kid)) => {
            // The signing key may have rotated: refresh once, then retry.
            tracing::info!(iss = %iss, kid = %kid, "back-channel logout: kid miss; refreshing JWKS");
            if let Err(refresh_err) = entry.provider.refresh_jwks().await {
                tracing::warn!(iss = %iss, error = %refresh_err, "JWKS refresh failed");
            }
            match entry.provider.verify_logout_jwt(token) {
                Ok(verified) => verified,
                Err(e) => {
                    tracing::warn!(iss = %iss, error = %e, "back-channel logout: JWT verification failed after JWKS refresh");
                    return Err(StatusCode::BAD_REQUEST);
                }
            }
        }
        Err(e) => {
            tracing::warn!(iss = %iss, error = %e, "back-channel logout: JWT verification failed");
            return Err(StatusCode::BAD_REQUEST);
        }
    };

    let expected_client = &entry.client_id;
    if !aud_contains(&payload, expected_client) {
        tracing::warn!(iss = %iss, "back-channel logout: aud does not contain client_id");
        return Err(StatusCode::BAD_REQUEST);
    }
    if !azp_satisfied(&payload, expected_client) {
        tracing::warn!(
            iss = %iss,
            "back-channel logout: aud is multi-valued and azp does not match client_id"
        );
        return Err(StatusCode::BAD_REQUEST);
    }

    match check_iat(&payload, self.clock.now().timestamp()) {
        IatCheck::OutOfRange { iat, now } => {
            tracing::warn!(
                iat = iat,
                now = now,
                "back-channel logout: iat too old or in the future"
            );
            return Err(StatusCode::BAD_REQUEST);
        }
        IatCheck::Missing => {
            tracing::warn!("back-channel logout: missing or invalid iat claim");
            return Err(StatusCode::BAD_REQUEST);
        }
        IatCheck::Ok => {}
    }

    if !events_contains_logout(&payload) {
        tracing::warn!("back-channel logout: events claim missing back-channel logout URI");
        return Err(StatusCode::BAD_REQUEST);
    }

    // Reads a claim from the verified payload as an owned string, if it is one.
    let claim_str = |name: &str| {
        payload
            .get(name)
            .and_then(|v| v.as_str())
            .map(|s| s.to_string())
    };

    let sub = claim_str("sub");
    let sid = claim_str("sid");
    if sub.is_none() && sid.is_none() {
        tracing::warn!("back-channel logout: neither sub nor sid present");
        return Err(StatusCode::BAD_REQUEST);
    }
    if payload.get("nonce").is_some() {
        tracing::warn!("back-channel logout: logout token must not contain nonce");
        return Err(StatusCode::BAD_REQUEST);
    }
    let jti = claim_str("jti");

    Ok(LogoutTokenClaims {
        sub,
        sid,
        iss: iss.to_string(),
        jti,
    })
}
/// Records `(issuer, jti)` as seen at `now` and reports whether it was new.
///
/// Before inserting, entries recorded more than `MAX_IAT_AGE_SECS` before
/// `now` are dropped. If the cache is still at or above `JTI_CACHE_MAX`, up
/// to `JTI_EVICT_BATCH` of the oldest remaining entries are evicted.
///
/// Returns `true` when the pair was not in the cache and `false` for a
/// replay. On a replay the stored timestamp is overwritten with `now`.
fn record_jti(&self, issuer: &str, jti: &str, now: chrono::DateTime<chrono::Utc>) -> bool {
    let seen = &*self.seen_jtis;

    let expiry_floor = now - chrono::Duration::seconds(MAX_IAT_AGE_SECS);
    seen.retain(|_, recorded_at| *recorded_at >= expiry_floor);

    if seen.len() >= JTI_CACHE_MAX {
        // Snapshot (timestamp, key) pairs first so no iterator guard is alive
        // while entries are being removed.
        let mut by_age: Vec<(chrono::DateTime<chrono::Utc>, (String, String))> = seen
            .iter()
            .map(|entry| (*entry.value(), entry.key().clone()))
            .collect();
        by_age.sort_by(|(left, _), (right, _)| left.cmp(right));
        by_age.truncate(JTI_EVICT_BATCH);
        for (_, stale_key) in by_age {
            seen.remove(&stale_key);
        }
    }

    seen.insert((issuer.to_string(), jti.to_string()), now)
        .is_none()
}
}
#[cfg(test)]
mod tests;