#[cfg(feature = "oidc")]
use crate::END_SESSION_ENDPOINT;
use crate::{
error::{StdbAuthCommandError, StdbAuthError},
message::StdbAuthCommandRejectedMessage,
plugin::PendingAuthOperation,
session::{StdbAuthCredentialMaterial, StdbAuthSession},
source::StdbAuthSource,
};
use bevy_ecs::{
message::Messages,
prelude::{Commands, Res, World},
system::{Command, SystemParam},
};
use bevy_tasks::{IoTaskPool, TaskPool};
#[cfg(feature = "oidc")]
use url::Url;
/// Identifies which auth command a [`StdbAuthCommandRejectedMessage`] refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StdbAuthOperationKind {
    /// Starting a login (`StdbAuthCommands::login`).
    Login,
    /// Ending the current session (`StdbAuthCommands::logout`).
    Logout,
    /// Manually refreshing the session (`StdbAuthCommands::refresh_now`).
    Refresh,
    /// Cancelling the pending operation (`StdbAuthCommands::cancel_pending`).
    Cancel,
}
/// Parameters for `StdbAuthCommands::login`.
#[derive(Clone, Debug)]
pub struct StdbLoginOptions {
    /// Where the new session is acquired from; acquisition runs in a background
    /// I/O task once the login command is applied.
    pub source: StdbAuthSource,
}

impl StdbLoginOptions {
    /// Creates login options that acquire the session from `source`.
    pub fn new(source: StdbAuthSource) -> Self {
        StdbLoginOptions { source }
    }
}
/// Parameters for `StdbAuthCommands::logout`.
#[derive(Clone, Debug)]
pub struct StdbLogoutOptions {
    /// Also end the session at the identity provider (only applies to OIDC sessions).
    pub end_provider_session: bool,
    /// Also clear persisted credentials for this device, where the build supports it.
    pub forget_device: bool,
}

impl Default for StdbLogoutOptions {
    /// Ends the provider session but keeps persisted device credentials.
    fn default() -> Self {
        let end_provider_session = true;
        let forget_device = false;
        Self {
            end_provider_session,
            forget_device,
        }
    }
}
/// System parameter for issuing auth commands from Bevy systems.
///
/// Each method checks the resources visible to the running system and returns an
/// error immediately when the request cannot be admitted. Accepted requests are
/// queued as commands that re-check admission when applied. Late failures are
/// reported as `StdbAuthCommandRejectedMessage`s, provided that message resource
/// is registered.
#[derive(SystemParam)]
pub struct StdbAuthCommands<'w, 's> {
    // Deferred queue the admitted operations are pushed onto.
    commands: Commands<'w, 's>,
    // Present while an auth operation (login/logout/refresh) is pending.
    pending_auth: Option<Res<'w, PendingAuthOperation>>,
    // Current session, if any.
    session: Option<Res<'w, StdbAuthSession>>,
    // Token material (e.g. refresh / id token) accompanying the session, if any.
    credentials: Option<Res<'w, StdbAuthCredentialMaterial>>,
}
impl StdbAuthCommands<'_, '_> {
    /// Queues a login using `options.source`.
    ///
    /// # Errors
    ///
    /// `StdbAuthCommandError::PendingOperation` if an auth operation is already pending.
    pub fn login(&mut self, options: StdbLoginOptions) -> Result<(), StdbAuthCommandError> {
        self.ensure_no_visible_pending_operation()?;
        self.commands.queue(StartLoginCommand { options });
        Ok(())
    }

    /// Queues a logout of the current session.
    ///
    /// # Errors
    ///
    /// `StdbAuthCommandError::PendingOperation` if an operation is pending, or
    /// `StdbAuthCommandError::NoSession` if there is no session to end.
    pub fn logout(&mut self, options: StdbLogoutOptions) -> Result<(), StdbAuthCommandError> {
        self.ensure_no_visible_pending_operation()?;
        match self.session {
            Some(_) => {
                self.commands.queue(StartLogoutCommand { options });
                Ok(())
            }
            None => Err(StdbAuthCommandError::NoSession),
        }
    }

    /// Queues an immediate (non-automatic) refresh of the current session.
    ///
    /// # Errors
    ///
    /// Checked in this order: `PendingOperation`, `NoSession`,
    /// `MissingRefreshToken`, `MissingClientId`.
    pub fn refresh_now(&mut self) -> Result<(), StdbAuthCommandError> {
        self.ensure_no_visible_pending_operation()?;
        let Some(session) = self.session.as_deref() else {
            return Err(StdbAuthCommandError::NoSession);
        };
        let has_refresh_token = match self.credentials.as_deref() {
            Some(credentials) => credentials.has_refresh_token(),
            None => false,
        };
        if !has_refresh_token {
            return Err(StdbAuthCommandError::MissingRefreshToken);
        }
        // The refresh task needs the client id to talk to the token endpoint.
        if session.client_id.is_none() {
            return Err(StdbAuthCommandError::MissingClientId);
        }
        self.commands.queue(StartRefreshCommand);
        Ok(())
    }

    /// Queues cancellation of the pending auth operation.
    ///
    /// # Errors
    ///
    /// `StdbAuthCommandError::NoPendingOperation` if nothing is pending.
    pub fn cancel_pending(&mut self) -> Result<(), StdbAuthCommandError> {
        if self.pending_auth.is_some() {
            self.commands.queue(CancelPendingAuthCommand);
            Ok(())
        } else {
            Err(StdbAuthCommandError::NoPendingOperation)
        }
    }

    /// Fails with `PendingOperation` if this system can already see a pending operation.
    fn ensure_no_visible_pending_operation(&self) -> Result<(), StdbAuthCommandError> {
        match self.pending_auth {
            Some(_) => Err(StdbAuthCommandError::PendingOperation),
            None => Ok(()),
        }
    }
}
/// Deferred half of `StdbAuthCommands::login`: re-checks admission, then spawns
/// the session-acquisition task.
struct StartLoginCommand {
    options: StdbLoginOptions,
}

impl Command for StartLoginCommand {
    fn apply(self, world: &mut World) {
        // Another operation may have been admitted between queueing and applying.
        if reject_if_pending(world, StdbAuthOperationKind::Login) {
            return;
        }
        let source = self.options.source;
        let pool = IoTaskPool::get_or_init(TaskPool::default);
        let login = pool.spawn(async move { source.acquire_session().await });
        world.insert_resource(PendingAuthOperation::Login(login));
    }
}
/// Deferred half of `StdbAuthCommands::logout`: re-checks admission, snapshots the
/// session, and spawns the logout task.
struct StartLogoutCommand {
    options: StdbLogoutOptions,
}

impl Command for StartLogoutCommand {
    fn apply(self, world: &mut World) {
        if reject_if_pending(world, StdbAuthOperationKind::Logout) {
            return;
        }
        // The background task cannot access the world, so copy out what it needs.
        let snapshot = world.get_resource::<StdbAuthSession>().cloned();
        let Some(session) = snapshot else {
            reject_auth_command(
                world,
                StdbAuthOperationKind::Logout,
                StdbAuthCommandError::NoSession,
            );
            return;
        };
        let id_token_hint = world
            .get_resource::<StdbAuthCredentialMaterial>()
            .and_then(|credentials| credentials.id_token.as_ref())
            .cloned();
        // Named so they do not shadow the `end_provider_session` function.
        let forget_device = self.options.forget_device;
        let should_end_session = self.options.end_provider_session;
        let pool = IoTaskPool::get_or_init(TaskPool::default);
        let logout = pool.spawn(async move {
            if forget_device {
                clear_persisted_credentials_best_effort(&session);
            }
            if should_end_session {
                end_provider_session(&session, id_token_hint.as_deref())?;
            }
            Ok::<(), StdbAuthError>(())
        });
        world.insert_resource(PendingAuthOperation::Logout(logout));
    }
}
/// Deferred half of `StdbAuthCommands::refresh_now`: re-validates every refresh
/// precondition against the world, then spawns a manual refresh task.
struct StartRefreshCommand;

impl Command for StartRefreshCommand {
    fn apply(self, world: &mut World) {
        let kind = StdbAuthOperationKind::Refresh;
        if reject_if_pending(world, kind) {
            return;
        }
        let Some(session) = world.get_resource::<StdbAuthSession>().cloned() else {
            reject_auth_command(world, kind, StdbAuthCommandError::NoSession);
            return;
        };
        let Some(refresh_token) = world
            .get_resource::<StdbAuthCredentialMaterial>()
            .and_then(|credentials| credentials.refresh_token.clone())
        else {
            reject_auth_command(world, kind, StdbAuthCommandError::MissingRefreshToken);
            return;
        };
        if session.client_id.is_none() {
            reject_auth_command(world, kind, StdbAuthCommandError::MissingClientId);
            return;
        }
        let task = crate::refresh::spawn_refresh_session_task(session, refresh_token);
        // `automatic: false` marks this refresh as user-requested.
        world.insert_resource(PendingAuthOperation::Refresh {
            task,
            automatic: false,
        });
    }
}
/// Deferred half of `StdbAuthCommands::cancel_pending`: drops the pending operation.
struct CancelPendingAuthCommand;

impl Command for CancelPendingAuthCommand {
    fn apply(self, world: &mut World) {
        // Removing the resource drops its task handle; bevy_tasks cancels a `Task`
        // when it is dropped.
        if world.remove_resource::<PendingAuthOperation>().is_some() {
            return;
        }
        reject_auth_command(
            world,
            StdbAuthOperationKind::Cancel,
            StdbAuthCommandError::NoPendingOperation,
        );
    }
}
#[cfg(feature = "oidc")]
fn end_provider_session(
session: &StdbAuthSession,
id_token_hint: Option<&str>,
) -> Result<(), StdbAuthError> {
if session.source != crate::session::StdbAuthSessionSource::Oidc {
return Ok(());
}
let end_session_url = build_end_session_url(session, id_token_hint);
#[cfg(all(feature = "browser", target_arch = "wasm32"))]
{
web_sys::window()
.ok_or_else(|| StdbAuthError::Internal("browser window is unavailable".to_string()))?
.location()
.assign(end_session_url.as_str())
.map_err(|error| {
StdbAuthError::Internal(format!(
"failed to redirect to SpacetimeAuth logout: {error:?}"
))
})?;
}
#[cfg(not(target_arch = "wasm32"))]
{
webbrowser::open(end_session_url.as_str()).map_err(|error| {
StdbAuthError::Internal(format!("failed to open SpacetimeAuth logout URL: {error}"))
})?;
}
Ok(())
}
#[cfg(not(feature = "oidc"))]
fn end_provider_session(
_session: &StdbAuthSession,
_id_token_hint: Option<&str>,
) -> Result<(), StdbAuthError> {
Ok(())
}
#[cfg(feature = "oidc")]
fn build_end_session_url(session: &StdbAuthSession, id_token_hint: Option<&str>) -> Url {
let mut end_session_url = Url::parse(END_SESSION_ENDPOINT)
.expect("static SpacetimeAuth end-session endpoint must be valid");
let mut params = Vec::new();
if let Some(id_token_hint) = id_token_hint.filter(|token| !token.trim().is_empty()) {
params.push(("id_token_hint", id_token_hint));
}
if let Some(post_logout_redirect_uri) = session
.post_logout_redirect_uri
.as_deref()
.filter(|uri| !uri.trim().is_empty())
{
params.push(("post_logout_redirect_uri", post_logout_redirect_uri));
}
if let Some(client_id) = session
.client_id
.as_deref()
.filter(|client_id| !client_id.trim().is_empty())
{
params.push(("client_id", client_id));
}
if !params.is_empty() {
end_session_url.query_pairs_mut().extend_pairs(params);
}
end_session_url
}
#[cfg(all(feature = "oidc", feature = "persistence", not(target_arch = "wasm32")))]
fn clear_persisted_credentials_best_effort(session: &StdbAuthSession) {
if session.source == crate::session::StdbAuthSessionSource::Oidc
&& let Some(client_id) = session.client_id.as_deref()
{
crate::oidc::persistence::clear_refresh_token_best_effort(client_id);
}
}
#[cfg(not(all(feature = "oidc", feature = "persistence", not(target_arch = "wasm32"))))]
fn clear_persisted_credentials_best_effort(_session: &StdbAuthSession) {}
/// If an auth operation is already pending, rejects `operation` with
/// `StdbAuthCommandError::PendingOperation`.
///
/// Returns `true` when the caller must stop without starting its operation.
fn reject_if_pending(world: &mut World, operation: StdbAuthOperationKind) -> bool {
    let busy = world.contains_resource::<PendingAuthOperation>();
    if busy {
        reject_auth_command(world, operation, StdbAuthCommandError::PendingOperation);
    }
    busy
}
/// Reports a rejected auth command by writing a `StdbAuthCommandRejectedMessage`.
///
/// Silently does nothing if the `Messages<StdbAuthCommandRejectedMessage>`
/// resource is not present in `world`.
fn reject_auth_command(
    world: &mut World,
    operation: StdbAuthOperationKind,
    error: StdbAuthCommandError,
) {
    let Some(mut messages) =
        world.get_resource_mut::<Messages<StdbAuthCommandRejectedMessage>>()
    else {
        return;
    };
    messages.write(StdbAuthCommandRejectedMessage { operation, error });
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::StdbAuthSessionSource;

    /// An OIDC session with a client id; refresh additionally needs credential material.
    fn session_with_refresh_credentials() -> StdbAuthSession {
        StdbAuthSession {
            access_token: "access".to_string(),
            token_type: "Bearer".to_string(),
            expires_at: None,
            can_refresh: true,
            scope: None,
            client_id: Some("client".to_string()),
            source: StdbAuthSessionSource::Oidc,
            post_logout_redirect_uri: None,
        }
    }

    /// A world with the rejection message resource registered, so rejections are observable.
    fn world_with_rejection_messages() -> World {
        let mut world = World::new();
        world.init_resource::<Messages<StdbAuthCommandRejectedMessage>>();
        world
    }

    /// A trivial logout operation used to occupy the pending-operation slot.
    fn pending_logout() -> PendingAuthOperation {
        let task =
            IoTaskPool::get_or_init(TaskPool::default).spawn(async { Ok::<(), StdbAuthError>(()) });
        PendingAuthOperation::Logout(task)
    }

    /// Asserts that exactly one rejection was written this update, matching `operation`/`error`.
    fn assert_single_rejection(
        world: &World,
        operation: StdbAuthOperationKind,
        error: StdbAuthCommandError,
    ) {
        let messages = world.resource::<Messages<StdbAuthCommandRejectedMessage>>();
        let rejected = messages.iter_current_update_messages().collect::<Vec<_>>();
        assert_eq!(rejected.len(), 1);
        assert_eq!(rejected[0].operation, operation);
        assert_eq!(rejected[0].error, error);
    }

    /// Asserts that no rejection was written this update.
    fn assert_no_rejections(world: &World) {
        let messages = world.resource::<Messages<StdbAuthCommandRejectedMessage>>();
        assert_eq!(messages.iter_current_update_messages().count(), 0);
    }

    #[test]
    fn logout_options_end_provider_session_by_default() {
        let options = StdbLogoutOptions::default();
        assert!(options.end_provider_session);
        assert!(!options.forget_device);
    }

    #[cfg(feature = "oidc")]
    #[test]
    fn end_session_url_contains_logout_context() {
        let mut session = session_with_refresh_credentials();
        session.post_logout_redirect_uri = Some("http://127.0.0.1:3000/logged-out".to_string());
        let end_session_url = build_end_session_url(&session, Some("id-token"));
        let params = end_session_url
            .query_pairs()
            .map(|(key, value)| (key.into_owned(), value.into_owned()))
            .collect::<std::collections::BTreeMap<_, _>>();
        assert_eq!(
            end_session_url.as_str().split('?').next(),
            Some("https://auth.spacetimedb.com/oidc/session/end")
        );
        assert_eq!(params.get("client_id").map(String::as_str), Some("client"));
        assert_eq!(
            params.get("id_token_hint").map(String::as_str),
            Some("id-token")
        );
        assert_eq!(
            params.get("post_logout_redirect_uri").map(String::as_str),
            Some("http://127.0.0.1:3000/logged-out")
        );
    }

    #[cfg(feature = "oidc")]
    #[test]
    fn end_session_url_omits_empty_optional_context() {
        let mut session = session_with_refresh_credentials();
        session.client_id = Some(" ".to_string());
        session.post_logout_redirect_uri = Some(" ".to_string());
        let end_session_url = build_end_session_url(&session, Some(" "));
        assert!(end_session_url.query().is_none());
    }

    #[test]
    fn refresh_admission_rejects_second_same_frame_request() {
        let mut world = world_with_rejection_messages();
        world.insert_resource(pending_logout());
        StartRefreshCommand.apply(&mut world);
        assert!(world.contains_resource::<PendingAuthOperation>());
        assert_single_rejection(
            &world,
            StdbAuthOperationKind::Refresh,
            StdbAuthCommandError::PendingOperation,
        );
    }

    #[test]
    fn refresh_admission_rejects_missing_credentials() {
        let mut world = world_with_rejection_messages();
        world.insert_resource(session_with_refresh_credentials());
        StartRefreshCommand.apply(&mut world);
        assert_single_rejection(
            &world,
            StdbAuthOperationKind::Refresh,
            StdbAuthCommandError::MissingRefreshToken,
        );
    }

    #[test]
    fn refresh_admission_rejects_missing_client_id() {
        let mut world = world_with_rejection_messages();
        let mut session = session_with_refresh_credentials();
        session.client_id = None;
        world.insert_resource(session);
        world.insert_resource(StdbAuthCredentialMaterial::new(
            Some("refresh".to_string()),
            None,
        ));
        StartRefreshCommand.apply(&mut world);
        assert_single_rejection(
            &world,
            StdbAuthOperationKind::Refresh,
            StdbAuthCommandError::MissingClientId,
        );
    }

    #[test]
    fn refresh_admission_rejects_missing_session() {
        let mut world = world_with_rejection_messages();
        StartRefreshCommand.apply(&mut world);
        assert_single_rejection(
            &world,
            StdbAuthOperationKind::Refresh,
            StdbAuthCommandError::NoSession,
        );
    }

    #[test]
    fn logout_admission_rejects_missing_session() {
        let mut world = world_with_rejection_messages();
        let command = StartLogoutCommand {
            options: StdbLogoutOptions::default(),
        };
        command.apply(&mut world);
        assert!(!world.contains_resource::<PendingAuthOperation>());
        assert_single_rejection(
            &world,
            StdbAuthOperationKind::Logout,
            StdbAuthCommandError::NoSession,
        );
    }

    #[test]
    fn logout_admission_rejects_pending_operation() {
        let mut world = world_with_rejection_messages();
        world.insert_resource(session_with_refresh_credentials());
        world.insert_resource(pending_logout());
        let command = StartLogoutCommand {
            options: StdbLogoutOptions::default(),
        };
        command.apply(&mut world);
        assert_single_rejection(
            &world,
            StdbAuthOperationKind::Logout,
            StdbAuthCommandError::PendingOperation,
        );
    }

    #[test]
    fn cancel_rejects_when_nothing_is_pending() {
        let mut world = world_with_rejection_messages();
        CancelPendingAuthCommand.apply(&mut world);
        assert_single_rejection(
            &world,
            StdbAuthOperationKind::Cancel,
            StdbAuthCommandError::NoPendingOperation,
        );
    }

    #[test]
    fn cancel_removes_pending_operation_without_rejection() {
        let mut world = world_with_rejection_messages();
        world.insert_resource(pending_logout());
        CancelPendingAuthCommand.apply(&mut world);
        assert!(!world.contains_resource::<PendingAuthOperation>());
        assert_no_rejections(&world);
    }
}