use axum::extract::{Query, State};
use axum::http::StatusCode;
use axum::response::{Html, IntoResponse, Response};
use serde::Deserialize;
use std::collections::HashSet;
use std::sync::Arc;
use super::backchannel_logout::{SidKey, SidMap};
use crate::session::store::SessionRevoker;
/// Query-string parameters of a front-channel logout request (`?iss=…&sid=…`).
///
/// Both fields are `Option` so that deserialization succeeds when a parameter
/// is absent. The handler then rejects the request itself, with a log line
/// naming the missing parameter.
#[derive(Debug, Deserialize)]
pub struct FrontChannelLogoutParams {
    /// OIDC session identifier to log out; looked up together with `iss`.
    pub sid: Option<String>,
    /// Issuer identifier; must be one of the handler's known issuers.
    pub iss: Option<String>,
}
/// Shared state behind the front-channel logout endpoint.
///
/// Every field is held behind an `Arc`, so cloning the handler (as axum's
/// `State` extractor does per request) only bumps reference counts.
#[derive(Clone)]
pub struct FrontChannelLogoutHandler {
    /// `(iss, sid)` -> `(user_id, session_id, inserted_at)` lookup populated
    /// elsewhere; entries are removed when a logout for them is processed.
    sid_map: SidMap,
    /// Revokes local sessions once their OIDC `sid` has been logged out.
    registry: Arc<dyn SessionRevoker>,
    /// Allowlist of issuer identifiers whose logout requests are honoured.
    known_issuers: Arc<HashSet<String>>,
}
impl FrontChannelLogoutHandler {
pub fn new(
sid_map: SidMap,
registry: Arc<dyn SessionRevoker>,
known_issuers: HashSet<String>,
) -> Self {
Self {
sid_map,
registry,
known_issuers: Arc::new(known_issuers),
}
}
pub async fn handle_frontchannel_logout(
State(handler): State<FrontChannelLogoutHandler>,
Query(params): Query<FrontChannelLogoutParams>,
) -> Response {
let iss = match ¶ms.iss {
Some(iss) if handler.known_issuers.contains(iss.as_str()) => iss,
Some(iss) => {
tracing::warn!(
iss = %iss,
"front-channel logout: unknown issuer; rejecting"
);
return StatusCode::BAD_REQUEST.into_response();
}
None => {
tracing::warn!("front-channel logout: missing iss parameter; rejecting");
return StatusCode::BAD_REQUEST.into_response();
}
};
let sid = match ¶ms.sid {
Some(sid) => sid,
None => {
tracing::warn!("front-channel logout: missing sid parameter; rejecting");
return StatusCode::BAD_REQUEST.into_response();
}
};
let key: SidKey = (iss.clone(), sid.clone());
if let Some((_, (user_id, session_id, _inserted_at))) = handler.sid_map.remove(&key) {
tracing::info!(
oidc_sid = %sid,
user_id = %user_id,
iss = %iss,
"front-channel logout: invalidating session by OIDC sid"
);
handler
.registry
.invalidate_session(&user_id, &session_id)
.await;
} else {
tracing::debug!(
iss = %iss,
oidc_sid = %sid,
"front-channel logout: sid not found in sid map (session may already be invalidated)"
);
}
(
StatusCode::OK,
[("cache-control", "no-store")],
Html("<!DOCTYPE html><html><body></body></html>"),
)
.into_response()
}
}
#[cfg(test)]
mod tests {
    //! Tests drive the real `handle_frontchannel_logout` entry point and check
    //! both the HTTP response and its effect on the registry and sid map.

    use super::*;
    use crate::session::id::SessionId;
    use crate::session::store::SessionRegistryAdapter;
    use crate::session::store::{MemorySessionRegistry, SessionRegistry};
    use crate::testing::mock_random::MockRng;
    use dashmap::DashMap;

    const KNOWN_ISS: &str = "https://idp.example.com";
    const EVIL_ISS: &str = "https://evil.example.com";

    fn known_issuers() -> HashSet<String> {
        [KNOWN_ISS.to_string()].into_iter().collect()
    }

    fn sid_key(iss: &str, sid: &str) -> SidKey {
        (iss.to_string(), sid.to_string())
    }

    /// Builds a handler over `sid_map` whose revoker wraps `registry`.
    fn make_handler(sid_map: SidMap, registry: MemorySessionRegistry) -> FrontChannelLogoutHandler {
        FrontChannelLogoutHandler::new(
            sid_map,
            Arc::new(SessionRegistryAdapter(registry)),
            known_issuers(),
        )
    }

    /// Invokes the axum handler exactly as the router would.
    async fn call(
        handler: &FrontChannelLogoutHandler,
        sid: Option<&str>,
        iss: Option<&str>,
    ) -> Response {
        FrontChannelLogoutHandler::handle_frontchannel_logout(
            State(handler.clone()),
            Query(FrontChannelLogoutParams {
                sid: sid.map(str::to_owned),
                iss: iss.map(str::to_owned),
            }),
        )
        .await
    }

    #[tokio::test]
    async fn frontchannel_logout_invalidates_by_sid() {
        let registry = MemorySessionRegistry::new();
        let sid_map: SidMap = Arc::new(DashMap::new());
        let rng = MockRng::new(42);
        let session_id = SessionId::new(&rng);
        let user_id = axess_identity::testing::user("user-1");
        registry.register(&user_id, &session_id).await.unwrap();
        let key = sid_key(KNOWN_ISS, "oidc-sid-123");
        sid_map.insert(key.clone(), (user_id, session_id, chrono::Utc::now()));
        assert!(registry.is_valid(&user_id, &session_id).await.unwrap());
        let handler = make_handler(sid_map.clone(), registry.clone());

        let response = call(&handler, Some("oidc-sid-123"), Some(KNOWN_ISS)).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.headers()["cache-control"], "no-store");
        assert!(!registry.is_valid(&user_id, &session_id).await.unwrap());
        assert!(!sid_map.contains_key(&key));
    }

    #[tokio::test]
    async fn frontchannel_logout_unknown_sid_is_noop() {
        let registry = MemorySessionRegistry::new();
        let sid_map: SidMap = Arc::new(DashMap::new());
        let rng = MockRng::new(43);
        let session_id = SessionId::new(&rng);
        let user_id = axess_identity::testing::user("user-1");
        registry.register(&user_id, &session_id).await.unwrap();
        let key = sid_key(KNOWN_ISS, "oidc-sid-123");
        sid_map.insert(key.clone(), (user_id, session_id, chrono::Utc::now()));
        let handler = make_handler(sid_map.clone(), registry.clone());

        let response = call(&handler, Some("nonexistent"), Some(KNOWN_ISS)).await;

        assert_eq!(
            response.status(),
            StatusCode::OK,
            "an unknown sid is treated as already logged out"
        );
        assert!(registry.is_valid(&user_id, &session_id).await.unwrap());
        assert!(sid_map.contains_key(&key));
    }

    #[tokio::test]
    async fn frontchannel_logout_rejects_unknown_issuer() {
        let registry = MemorySessionRegistry::new();
        let sid_map: SidMap = Arc::new(DashMap::new());
        let rng = MockRng::new(44);
        let session_id = SessionId::new(&rng);
        let user_id = axess_identity::testing::user("user-1");
        registry.register(&user_id, &session_id).await.unwrap();
        // Even a matching map entry must not be reachable through an issuer
        // outside the allowlist: the gate has to run before the lookup.
        let key = sid_key(EVIL_ISS, "s1");
        sid_map.insert(key.clone(), (user_id, session_id, chrono::Utc::now()));
        let handler = make_handler(sid_map.clone(), registry.clone());

        let response = call(&handler, Some("s1"), Some(EVIL_ISS)).await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert!(registry.is_valid(&user_id, &session_id).await.unwrap());
        assert!(sid_map.contains_key(&key));
    }

    #[tokio::test]
    async fn frontchannel_logout_rejects_missing_iss() {
        let registry = MemorySessionRegistry::new();
        let sid_map: SidMap = Arc::new(DashMap::new());
        let rng = MockRng::new(45);
        let session_id = SessionId::new(&rng);
        let user_id = axess_identity::testing::user("user-1");
        registry.register(&user_id, &session_id).await.unwrap();
        let key = sid_key(KNOWN_ISS, "s1");
        sid_map.insert(key.clone(), (user_id, session_id, chrono::Utc::now()));
        let handler = make_handler(sid_map.clone(), registry.clone());

        let response = call(&handler, Some("s1"), None).await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert!(registry.is_valid(&user_id, &session_id).await.unwrap());
        assert!(sid_map.contains_key(&key));
    }

    #[tokio::test]
    async fn frontchannel_logout_rejects_missing_sid() {
        let handler = make_handler(Arc::new(DashMap::new()), MemorySessionRegistry::new());

        let response = call(&handler, None, Some(KNOWN_ISS)).await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn handle_frontchannel_logout_drives_registry_and_gates_issuer() {
        let registry = MemorySessionRegistry::new();
        let sid_map: SidMap = Arc::new(DashMap::new());
        let rng = MockRng::new(7);
        let session_id = SessionId::new(&rng);
        let user_id = axess_identity::testing::user("u1");
        registry.register(&user_id, &session_id).await.unwrap();
        sid_map.insert(
            sid_key(KNOWN_ISS, "oidc-sid-XYZ"),
            (user_id, session_id, chrono::Utc::now()),
        );
        let handler = make_handler(sid_map.clone(), registry.clone());

        let response = call(&handler, Some("oidc-sid-XYZ"), Some(KNOWN_ISS)).await;
        assert_eq!(
            response.status(),
            StatusCode::OK,
            "known iss + known sid must return 200 OK"
        );
        assert!(
            !registry.is_valid(&user_id, &session_id).await.unwrap(),
            "front-channel logout must invalidate the registry session; \
             a non-default response without side effect is silent failure"
        );

        let response = call(&handler, Some("anything"), Some(EVIL_ISS)).await;
        assert_eq!(
            response.status(),
            StatusCode::BAD_REQUEST,
            "unknown iss must be rejected as 400"
        );
    }
}