#[cfg(feature = "sso-login")]
use std::future::Future;
use std::{borrow::Cow, fmt};
use matrix_sdk_base::{SessionMeta, store::RoomLoadSettings};
use ruma::{
api::{
OutgoingRequest,
auth_scheme::SendAccessToken,
client::{
account::register,
session::{
get_login_types, login, logout, refresh_token, sso_login, sso_login_with_provider,
},
uiaa::UserIdentifier,
},
},
serde::JsonObject,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{debug, error, info, instrument};
use url::Url;
use crate::{
Client, Error, RefreshTokenError, Result,
authentication::AuthData,
client::SessionChange,
error::{HttpError, HttpResult},
};
mod login_builder;
pub use self::login_builder::LoginBuilder;
#[cfg(feature = "sso-login")]
pub use self::login_builder::SsoLoginBuilder;
use super::SessionTokens;
/// Handle grouping the Matrix-native authentication operations of a [`Client`].
///
/// It only wraps a [`Client`]; the builders created by the `login_*` methods
/// receive a clone of this handle.
#[derive(Debug, Clone)]
pub struct MatrixAuth {
// The client on which every request in this module is sent.
client: Client,
}
/// Errors that can occur while handling an SSO login.
#[derive(Debug, Error)]
pub enum SsoError {
/// The callback URL's query string could not be parsed, or it did not
/// contain a `loginToken` parameter.
#[error("callback URL invalid")]
CallbackUrlInvalid,
}
impl MatrixAuth {
pub(crate) fn new(client: Client) -> Self {
Self { client }
}
/// Sends a `get_login_types::v3::Request` through the client and returns
/// the server's response.
///
/// # Errors
///
/// Returns the [`HttpError`] produced by the client if the request fails.
pub async fn get_login_types(&self) -> HttpResult<get_login_types::v3::Response> {
    self.client.send(get_login_types::v3::Request::new()).await
}
pub async fn get_sso_login_url(
&self,
redirect_url: &str,
idp_id: Option<&str>,
) -> Result<String> {
let homeserver = self.client.homeserver();
let supported_versions = self.client.supported_versions().await?;
let request = if let Some(id) = idp_id {
sso_login_with_provider::v3::Request::new(id.to_owned(), redirect_url.to_owned())
.try_into_http_request::<Vec<u8>>(
homeserver.as_str(),
SendAccessToken::None,
Cow::Owned(supported_versions),
)
} else {
sso_login::v3::Request::new(redirect_url.to_owned()).try_into_http_request::<Vec<u8>>(
homeserver.as_str(),
SendAccessToken::None,
Cow::Owned(supported_versions),
)
};
match request {
Ok(req) => Ok(req.uri().to_string()),
Err(err) => Err(Error::from(HttpError::IntoHttp(err))),
}
}
/// Starts a password login for a user ID or localpart.
///
/// `id` is wrapped in [`UserIdentifier::UserIdOrLocalpart`] and forwarded,
/// together with `password`, to [`MatrixAuth::login_identifier`].
pub fn login_username(&self, id: impl AsRef<str>, password: &str) -> LoginBuilder {
    let identifier = UserIdentifier::UserIdOrLocalpart(id.as_ref().to_owned());
    self.login_identifier(identifier, password)
}
/// Starts a password login for an arbitrary [`UserIdentifier`].
///
/// Returns a [`LoginBuilder`] holding a clone of this handle, the
/// identifier and an owned copy of the password.
pub fn login_identifier(&self, id: UserIdentifier, password: &str) -> LoginBuilder {
    let auth = self.clone();
    LoginBuilder::new_password(auth, id, password.to_owned())
}
/// Starts a login with a custom login type.
///
/// # Arguments
///
/// * `login_type` - The identifier of the custom login type.
/// * `data` - The JSON object passed along with the login type.
///
/// # Errors
///
/// Returns the `serde_json` error reported by `LoginBuilder::new_custom`.
pub fn login_custom(
    &self,
    login_type: &str,
    data: JsonObject,
) -> serde_json::Result<LoginBuilder> {
    let auth = self.clone();
    LoginBuilder::new_custom(auth, login_type, data)
}
/// Starts a login with a login token.
///
/// Returns a [`LoginBuilder`] holding a clone of this handle and an owned
/// copy of `token`.
pub fn login_token(&self, token: &str) -> LoginBuilder {
    let auth = self.clone();
    LoginBuilder::new_token(auth, token.to_owned())
}
/// Starts a token login from the URL an SSO flow redirected to.
///
/// The `loginToken` query parameter is extracted from `callback_url` and
/// handed to [`MatrixAuth::login_token`]. A URL without a query string is
/// treated as having an empty one.
///
/// # Errors
///
/// Returns [`SsoError::CallbackUrlInvalid`] if the query string cannot be
/// parsed or does not contain a `loginToken` parameter.
pub fn login_with_sso_callback(&self, callback_url: Url) -> Result<LoginBuilder, SsoError> {
    /// The part of the callback query string this method cares about.
    #[derive(Deserialize)]
    struct CallbackQuery {
        #[serde(rename = "loginToken")]
        login_token: Option<String>,
    }

    let raw_query = callback_url.query().unwrap_or("");
    let Ok(CallbackQuery { login_token }) = serde_html_form::from_str::<CallbackQuery>(raw_query)
    else {
        return Err(SsoError::CallbackUrlInvalid);
    };

    match login_token {
        Some(token) => Ok(self.login_token(&token)),
        None => Err(SsoError::CallbackUrlInvalid),
    }
}
/// Starts an SSO login driven by a caller-supplied callback.
///
/// # Arguments
///
/// * `use_sso_login_url` - A one-shot closure that receives a `String` and
///   returns a future resolving to `Result<()>`. It is stored in the
///   returned [`SsoLoginBuilder`] together with a clone of this handle.
#[cfg(feature = "sso-login")]
pub fn login_sso<F, Fut>(&self, use_sso_login_url: F) -> SsoLoginBuilder<F>
where
    F: FnOnce(String) -> Fut + Send,
    Fut: Future<Output = Result<()>> + Send,
{
    let auth = self.clone();
    SsoLoginBuilder::new(auth, use_sso_login_url)
}
/// Returns `true` if the client's authentication data is set and is
/// [`AuthData::Matrix`]; returns `false` when it is unset or of another
/// kind.
pub fn logged_in(&self) -> bool {
    let current = self.client.auth_ctx().auth_data.get();
    matches!(current, Some(AuthData::Matrix))
}
/// Refreshes the access token using the session's refresh token.
///
/// Only one refresh runs at a time: the outcome of a refresh is stored
/// inside `refresh_token_lock`, and a caller that finds the lock already
/// taken waits for it and returns a clone of the stored outcome instead of
/// sending its own request.
///
/// On success the new access token (and the new refresh token, if the
/// response carries one) is stored in the auth context, the save-session
/// callback is invoked if one is set, and `SessionChange::TokensRefreshed`
/// is sent on the session-change channel.
///
/// # Errors
///
/// * [`RefreshTokenError::RefreshTokenRequired`] if the client has no
///   session tokens or the tokens contain no refresh token.
/// * [`RefreshTokenError::MatrixAuth`] if the refresh request fails.
pub async fn refresh_access_token(&self) -> Result<(), RefreshTokenError> {
// Records the error inside the lock before returning it, so that callers
// waiting on the lock observe the same failure.
macro_rules! fail {
($lock:expr, $err:expr) => {
let error = $err;
*$lock = Err(error.clone());
return Err(error);
};
}
let refresh_token_lock = &self.client.auth_ctx().refresh_token_lock;
let Ok(mut guard) = refresh_token_lock.try_lock() else {
// The lock is already held: wait until it is released, then report the
// outcome stored in it.
return refresh_token_lock.lock().await.clone();
};
let Some(mut session_tokens) = self.client.session_tokens() else {
fail!(guard, RefreshTokenError::RefreshTokenRequired);
};
let Some(refresh_token) = session_tokens.refresh_token.clone() else {
fail!(guard, RefreshTokenError::RefreshTokenRequired);
};
let request = refresh_token::v3::Request::new(refresh_token);
// NOTE(review): this goes through `send_inner` rather than `send` —
// presumably to bypass the regular send path's own token handling; confirm.
let res = self.client.send_inner(request, None, Default::default()).await;
match res {
Ok(res) => {
*guard = Ok(());
session_tokens.access_token = res.access_token;
// Keep the previous refresh token when the response does not contain a
// new one.
if let Some(refresh_token) = res.refresh_token {
session_tokens.refresh_token = Some(refresh_token);
}
self.client.auth_ctx().set_session_tokens(session_tokens);
// A failing save-session callback is only logged; the refresh itself
// still counts as successful.
if let Some(save_session_callback) =
self.client.inner.auth_ctx.save_session_callback.get()
&& let Err(err) = save_session_callback(self.client.clone())
{
error!("when saving session after refresh: {err}");
}
// The result of the send is deliberately ignored.
_ = self
.client
.inner
.auth_ctx
.session_change_sender
.send(SessionChange::TokensRefreshed);
Ok(())
}
Err(error) => {
let error = RefreshTokenError::MatrixAuth(error.into());
fail!(guard, error);
}
}
}
#[instrument(skip_all)]
pub async fn register(&self, request: register::v3::Request) -> Result<register::v3::Response> {
let homeserver = self.client.homeserver();
info!("Registering to {homeserver}");
#[cfg(feature = "e2e-encryption")]
let login_info = match (&request.username, &request.password) {
(Some(u), Some(p)) => Some(login::v3::LoginInfo::Password(login::v3::Password::new(
UserIdentifier::UserIdOrLocalpart(u.into()),
p.clone(),
))),
_ => None,
};
let response = self.client.send(request).await?;
if let Some(session) = MatrixSession::from_register_response(&response) {
let _ = self
.set_session(
session,
RoomLoadSettings::default(),
#[cfg(feature = "e2e-encryption")]
login_info,
)
.await;
}
Ok(response)
}
/// Sends a `logout::v3::Request` through the client and returns the
/// server's response.
///
/// This method only issues the request; it does not itself modify any
/// session state held by the client.
///
/// # Errors
///
/// Returns the [`HttpError`] produced by the client if the request fails.
pub async fn logout(&self) -> HttpResult<logout::v3::Response> {
    self.client.send(logout::v3::Request::new()).await
}
/// Returns the current session as a [`MatrixSession`].
///
/// Yields `None` if the client has no session metadata or no session
/// tokens. The tokens are only looked up once the metadata is known to be
/// present.
pub fn session(&self) -> Option<MatrixSession> {
    self.client.session_meta().and_then(|meta| {
        self.client
            .session_tokens()
            .map(|tokens| MatrixSession { meta: meta.to_owned(), tokens })
    })
}
/// Installs a previously obtained [`MatrixSession`] on the client.
///
/// The session is passed to `set_session` with the given
/// `room_load_settings` and, with the `e2e-encryption` feature, without
/// any login info.
///
/// # Errors
///
/// Returns the error reported by `set_session`.
///
/// # Panics
///
/// `set_session` panics if the client's authentication data was already
/// set.
#[instrument(skip_all)]
pub async fn restore_session(
    &self,
    session: MatrixSession,
    room_load_settings: RoomLoadSettings,
) -> Result<()> {
    debug!("Restoring Matrix auth session");

    let restored = self
        .set_session(
            session,
            room_load_settings,
            #[cfg(feature = "e2e-encryption")]
            None,
        )
        .await;

    // Only announce completion when the session was actually installed.
    restored.map(|()| {
        debug!("Done restoring Matrix auth session");
    })
}
pub(crate) async fn receive_login_response(
&self,
response: &login::v3::Response,
#[cfg(feature = "e2e-encryption")] login_info: Option<login::v3::LoginInfo>,
) -> Result<()> {
self.client.maybe_update_login_well_known(response.well_known.as_ref());
self.set_session(
response.into(),
RoomLoadSettings::default(),
#[cfg(feature = "e2e-encryption")]
login_info,
)
.await?;
Ok(())
}
async fn set_session(
&self,
session: MatrixSession,
room_load_settings: RoomLoadSettings,
#[cfg(feature = "e2e-encryption")] login_info: Option<login::v3::LoginInfo>,
) -> Result<()> {
self.client
.auth_ctx()
.auth_data
.set(AuthData::Matrix)
.expect("Client authentication data was already set");
self.client.auth_ctx().set_session_tokens(session.tokens);
self.client
.base_client()
.activate(
session.meta,
room_load_settings,
#[cfg(feature = "e2e-encryption")]
None,
)
.await?;
#[cfg(feature = "e2e-encryption")]
{
use ruma::api::client::uiaa::{AuthData, Password};
let auth_data = match login_info {
Some(login::v3::LoginInfo::Password(login::v3::Password {
identifier: Some(identifier),
password,
..
})) => Some(AuthData::Password(Password::new(identifier, password))),
_ => None,
};
self.client.encryption().spawn_initialization_task(auth_data).await;
}
Ok(())
}
}
/// A user session obtained through Matrix-native authentication.
///
/// Both fields are marked `#[serde(flatten)]`, so their own fields are
/// (de)serialized at the top level of the session rather than nested under
/// `meta` / `tokens`.
#[derive(Clone, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub struct MatrixSession {
/// The session metadata (user ID and device ID).
#[serde(flatten)]
pub meta: SessionMeta,
/// The access token and optional refresh token of the session.
#[serde(flatten)]
pub tokens: SessionTokens,
}
/// Debug output that shows only the `meta` field; `tokens` is left out and
/// the struct is rendered as non-exhaustive (presumably to keep tokens out
/// of debug output).
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for MatrixSession {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("MatrixSession");
        out.field("meta", &self.meta);
        out.finish_non_exhaustive()
    }
}
/// Builds a [`MatrixSession`] by cloning the user ID, device ID, access
/// token and optional refresh token out of a login response.
impl From<&login::v3::Response> for MatrixSession {
    fn from(response: &login::v3::Response) -> Self {
        let meta = SessionMeta {
            user_id: response.user_id.clone(),
            device_id: response.device_id.clone(),
        };
        let tokens = SessionTokens {
            access_token: response.access_token.clone(),
            refresh_token: response.refresh_token.clone(),
        };

        Self { meta, tokens }
    }
}
impl MatrixSession {
    /// Builds a [`MatrixSession`] from a registration response.
    ///
    /// Returns `None` when the response lacks a device ID or an access
    /// token; the refresh token is optional and copied as-is.
    // NOTE(review): the lint is presumably silenced because clippy
    // misreports the destructuring `let` below — confirm before removing.
    #[allow(clippy::question_mark)]
    fn from_register_response(response: &register::v3::Response) -> Option<Self> {
        let register::v3::Response { user_id, access_token, device_id, refresh_token, .. } =
            response;

        Some(Self {
            meta: SessionMeta { user_id: user_id.clone(), device_id: device_id.clone()? },
            tokens: SessionTokens {
                access_token: access_token.clone()?,
                refresh_token: refresh_token.clone(),
            },
        })
    }
}