use std::{fmt, sync::Arc};
use matrix_sdk_base::{SessionMeta, locks::Mutex};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex as AsyncMutex, OnceCell, broadcast};
pub mod matrix;
pub mod oauth;
use self::{
matrix::MatrixAuth,
oauth::{OAuth, OAuthAuthData, OAuthCtx},
};
use crate::{Client, RefreshTokenError, SessionChange};
/// The tokens of an authenticated user session.
///
/// `Debug` is implemented manually for this type and prints no field values, so the token
/// strings are never exposed through `{:?}` formatting.
#[derive(Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
#[allow(missing_debug_implementations)]
pub struct SessionTokens {
    /// The access token of the session.
    pub access_token: String,
    /// The refresh token, if any.
    ///
    /// Skipped when serializing if it is `None`, and defaults to `None` when the field is
    /// missing from the serialized input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
}
/// Redacting `Debug` implementation: only the type name is printed, followed by `..` to signal
/// that the (secret) fields were deliberately left out.
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for SessionTokens {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // No `.field(...)` calls on purpose: the tokens must never reach logs.
        let mut redacted = f.debug_struct("SessionTokens");
        redacted.finish_non_exhaustive()
    }
}
/// The session tokens held by an `AuthCtx`, together with a flag recording whether the
/// current access token has been reported as expired.
pub(crate) struct SessionTokensState {
    /// The tokens themselves.
    inner: SessionTokens,
    /// Set to `true` by `AuthCtx::set_access_token_expired` when the reported token matches
    /// `inner.access_token`; reset to `false` every time `AuthCtx::set_session_tokens` stores
    /// new tokens.
    access_token_expired: bool,
}
/// Boxed, thread-safe error type returned by the session callbacks below.
pub(crate) type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
/// Callback for saving the session; it receives a [`Client`] handle and reports failure
/// through [`SessionCallbackError`]. When it is invoked is decided outside this file.
#[cfg(not(target_family = "wasm"))]
pub(crate) type SaveSessionCallback =
    dyn Fn(Client) -> Result<(), SessionCallbackError> + Send + Sync;
/// `wasm` variant of the save callback: identical, but without the `Send + Sync` bounds
/// (presumably because those targets are single-threaded — TODO confirm).
#[cfg(target_family = "wasm")]
pub(crate) type SaveSessionCallback = dyn Fn(Client) -> Result<(), SessionCallbackError>;
/// Callback for reloading the session: given a [`Client`] handle it returns a set of
/// [`SessionTokens`], or an error.
#[cfg(not(target_family = "wasm"))]
pub(crate) type ReloadSessionCallback =
    dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError> + Send + Sync;
/// `wasm` variant of the reload callback, without the `Send + Sync` bounds.
#[cfg(target_family = "wasm")]
pub(crate) type ReloadSessionCallback =
    dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError>;
/// Client-wide authentication state: session tokens, the selected authentication API's data,
/// the session callbacks, and the channel used to broadcast [`SessionChange`]s.
pub(crate) struct AuthCtx {
    /// OAuth-specific context, built in `new` from its `allow_insecure_oauth` argument.
    oauth: OAuthCtx,
    /// Flag passed to `new`.
    /// NOTE(review): presumably decides whether access tokens are refreshed automatically —
    /// verify at the call sites.
    pub(crate) handle_refresh_tokens: bool,
    /// Shared async mutex holding the result of a token refresh; starts as `Ok(())`.
    /// NOTE(review): presumably held for the duration of a refresh so concurrent callers wait
    /// for, and reuse, the stored result — confirm at the call sites.
    refresh_token_lock: Arc<AsyncMutex<Result<(), RefreshTokenError>>>,
    /// Broadcast sender for [`SessionChange`] notifications (created with capacity 1).
    pub(crate) session_change_sender: broadcast::Sender<SessionChange>,
    /// Authentication-API data, set at most once (see `AuthData`).
    pub(crate) auth_data: OnceCell<AuthData>,
    /// Current session tokens: initialized by the first `set_session_tokens` call, then
    /// updated in place through the inner mutex.
    tokens: OnceCell<Mutex<SessionTokensState>>,
    /// Optional callback for reloading the session tokens; can be set at most once.
    pub(crate) reload_session_callback: OnceCell<Box<ReloadSessionCallback>>,
    /// Optional callback for saving the session; can be set at most once.
    pub(crate) save_session_callback: OnceCell<Box<SaveSessionCallback>>,
}
impl AuthCtx {
    /// Builds an empty authentication context.
    ///
    /// * `handle_refresh_tokens` — stored unchanged in the `handle_refresh_tokens` field.
    /// * `allow_insecure_oauth` — forwarded to `OAuthCtx::new`.
    ///
    /// Auth data, tokens and both callbacks start unset; the refresh lock starts out holding
    /// `Ok(())` and the session-change channel has a capacity of 1.
    pub(crate) fn new(handle_refresh_tokens: bool, allow_insecure_oauth: bool) -> Self {
        let refresh_token_lock = Arc::new(AsyncMutex::new(Ok(())));
        let session_change_sender = broadcast::Sender::new(1);
        let oauth = OAuthCtx::new(allow_insecure_oauth);

        Self {
            oauth,
            handle_refresh_tokens,
            refresh_token_lock,
            session_change_sender,
            auth_data: OnceCell::new(),
            tokens: OnceCell::new(),
            reload_session_callback: OnceCell::new(),
            save_session_callback: OnceCell::new(),
        }
    }

    /// Returns a copy of the current session tokens, or `None` if none were ever set.
    pub(crate) fn session_tokens(&self) -> Option<SessionTokens> {
        self.tokens.get().map(|state| state.lock().inner.clone())
    }

    /// Returns a copy of the current access token, or `None` if no tokens were ever set.
    pub(crate) fn access_token(&self) -> Option<String> {
        self.tokens.get().map(|state| state.lock().inner.access_token.clone())
    }

    /// Returns `true` only when tokens have been set and the current access token has not
    /// been flagged as expired.
    pub(crate) fn has_valid_access_token(&self) -> bool {
        match self.tokens.get() {
            Some(state) => !state.lock().access_token_expired,
            None => false,
        }
    }

    /// Stores `session_tokens` as the current tokens and clears the "expired" flag.
    pub(crate) fn set_session_tokens(&self, session_tokens: SessionTokens) {
        let fresh = SessionTokensState { inner: session_tokens, access_token_expired: false };

        match self.tokens.get() {
            Some(existing) => *existing.lock() = fresh,
            None => {
                // NOTE(review): the `set` error is ignored, so if another caller initialized
                // the cell between `get` and `set`, these tokens are silently dropped.
                let _ = self.tokens.set(Mutex::new(fresh));
            }
        }
    }

    /// Flags the current access token as expired, but only if it is still equal to
    /// `access_token`; a report about a token that has since been replaced is ignored.
    pub(crate) fn set_access_token_expired(&self, access_token: &str) {
        let Some(state) = self.tokens.get() else {
            return;
        };

        let mut guard = state.lock();
        if guard.inner.access_token == access_token {
            guard.access_token_expired = true;
        }
    }
}
/// A handle to one of the supported authentication APIs.
///
/// Marked `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AuthApi {
    /// The native Matrix authentication API.
    Matrix(MatrixAuth),
    /// The OAuth authentication API.
    OAuth(OAuth),
}
/// A user session, tagged with the authentication API that produced it.
///
/// Marked `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AuthSession {
    /// A session from the native Matrix authentication API.
    Matrix(matrix::MatrixSession),
    /// A session from the OAuth API. Boxed — presumably to keep the enum small
    /// (clippy `large_enum_variant`); TODO confirm.
    OAuth(Box<oauth::OAuthSession>),
}
impl AuthSession {
pub fn meta(&self) -> &SessionMeta {
match self {
AuthSession::Matrix(session) => &session.meta,
AuthSession::OAuth(session) => &session.user.meta,
}
}
pub fn into_meta(self) -> SessionMeta {
match self {
AuthSession::Matrix(session) => session.meta,
AuthSession::OAuth(session) => session.user.meta,
}
}
pub fn access_token(&self) -> &str {
match self {
AuthSession::Matrix(session) => &session.tokens.access_token,
AuthSession::OAuth(session) => &session.user.tokens.access_token,
}
}
pub fn get_refresh_token(&self) -> Option<&str> {
match self {
AuthSession::Matrix(session) => session.tokens.refresh_token.as_deref(),
AuthSession::OAuth(session) => session.user.tokens.refresh_token.as_deref(),
}
}
}
impl From<matrix::MatrixSession> for AuthSession {
fn from(session: matrix::MatrixSession) -> Self {
Self::Matrix(session)
}
}
impl From<oauth::OAuthSession> for AuthSession {
fn from(session: oauth::OAuthSession) -> Self {
Self::OAuth(session.into())
}
}
/// Per-API authentication data, stored once in `AuthCtx::auth_data`.
#[derive(Debug)]
pub(crate) enum AuthData {
    /// The native Matrix API is in use; it needs no extra data here.
    Matrix,
    /// The OAuth API is in use, with its associated data.
    OAuth(OAuthAuthData),
}