#![cfg_attr(not(target_arch = "wasm32"), deny(clippy::future_not_send))]
#[cfg(feature = "sso-login")]
use std::future::Future;
use ruma::{
api::client::{session::login, uiaa::UserIdentifier},
assign,
};
use tracing::{info, instrument};
use super::Client;
use crate::{config::RequestConfig, Result};
/// The credentials a [`LoginBuilder`] authenticates with.
enum LoginMethod<'a> {
    /// A user identifier together with that user's password.
    UserPassword { id: UserIdentifier<'a>, password: &'a str },
    /// A login token (the SSO flow in this file logs in this way).
    Token(&'a str),
}

impl<'a> LoginMethod<'a> {
    /// Returns the user identifier, or `None` for token logins, which carry none.
    fn id(&self) -> Option<&UserIdentifier<'a>> {
        match *self {
            Self::UserPassword { ref id, .. } => Some(id),
            Self::Token(..) => None,
        }
    }

    /// Short human-readable name of the method, recorded on the login tracing span.
    fn tracing_desc(&self) -> &'static str {
        match *self {
            Self::Token(..) => "token",
            Self::UserPassword { .. } => "identifier and password",
        }
    }

    /// Builds the `LoginInfo` payload for the login request.
    ///
    /// The identifier is cloned; the password and token string slices are copied.
    fn to_login_info(&self) -> login::v3::LoginInfo<'a> {
        match *self {
            Self::UserPassword { ref id, password } => {
                let credentials = login::v3::Password::new(id.clone(), password);
                login::v3::LoginInfo::Password(credentials)
            }
            Self::Token(token) => {
                let credentials = login::v3::Token::new(token);
                login::v3::LoginInfo::Token(credentials)
            }
        }
    }
}
/// Builder for a login request authenticated by user identifier and password or
/// by a login token.
///
/// Obtained from the crate-internal `new_password` / `new_token` constructors;
/// the request is performed by [`send`](Self::send).
// `LoginMethod` does not implement `Debug`, presumably why no `Debug` impl is provided.
#[allow(missing_debug_implementations)]
pub struct LoginBuilder<'a> {
    /// Client used to send the request and to process the login response.
    client: Client,
    /// The credentials to authenticate with.
    login_method: LoginMethod<'a>,
    /// Device ID to log in as, if one was requested.
    device_id: Option<&'a str>,
    /// Display name for a newly created device, if one was set.
    initial_device_display_name: Option<&'a str>,
    /// Whether the request asks the server for a refresh token.
    request_refresh_token: bool,
}

impl<'a> LoginBuilder<'a> {
    /// Creates a builder for `login_method` with all optional settings unset.
    fn new(client: Client, login_method: LoginMethod<'a>) -> Self {
        Self {
            login_method,
            client,
            request_refresh_token: false,
            initial_device_display_name: None,
            device_id: None,
        }
    }

    /// Creates a builder that logs in with a user identifier and a password.
    pub(super) fn new_password(client: Client, id: UserIdentifier<'a>, password: &'a str) -> Self {
        let method = LoginMethod::UserPassword { id, password };
        Self::new(client, method)
    }

    /// Creates a builder that logs in with a login token.
    pub(super) fn new_token(client: Client, token: &'a str) -> Self {
        let method = LoginMethod::Token(token);
        Self::new(client, method)
    }

    /// Sets the device ID to send with the login request.
    pub fn device_id(self, value: &'a str) -> Self {
        Self { device_id: Some(value), ..self }
    }

    /// Sets the initial display name sent with the login request.
    pub fn initial_device_display_name(self, value: &'a str) -> Self {
        Self { initial_device_display_name: Some(value), ..self }
    }

    /// Asks the server to also return a refresh token.
    pub fn request_refresh_token(self) -> Self {
        Self { request_refresh_token: true, ..self }
    }

    /// Sends the login request and passes the response to the client.
    ///
    /// The request is sent with `RequestConfig::short_retry()`.
    ///
    /// # Errors
    ///
    /// Fails if sending the request fails or if the client rejects the login response.
    #[instrument(
        target = "matrix_sdk::client",
        name = "login",
        skip_all,
        fields(method = self.login_method.tracing_desc()),
    )]
    pub async fn send(self) -> Result<login::v3::Response> {
        let homeserver_url = self.client.homeserver().await;
        info!(
            homeserver = homeserver_url.as_str(),
            identifier = ?self.login_method.id(),
            "Logging in"
        );

        let login_info = self.login_method.to_login_info();
        let request = assign!(login::v3::Request::new(login_info), {
            device_id: self.device_id.map(Into::into),
            initial_device_display_name: self.initial_device_display_name,
            refresh_token: self.request_refresh_token,
        });

        let config = RequestConfig::short_retry();
        let response = self.client.send(request, Some(config)).await?;
        self.client.receive_login_response(&response).await?;
        Ok(response)
    }
}
/// Builder for logging in through Single Sign-On (SSO).
///
/// [`send`](Self::send) starts a local HTTP server for the homeserver's
/// redirect. It then gives the SSO login URL to the `use_sso_login_url` callback,
/// waits for a `loginToken` query parameter to reach the local server, and
/// finishes with a token login.
#[cfg(feature = "sso-login")]
#[allow(missing_debug_implementations)] // `F` is an arbitrary callback with no `Debug` bound.
pub struct SsoLoginBuilder<'a, F> {
    /// Client used for the SSO URL lookup and the final token login.
    client: Client,
    /// Callback given the SSO login URL. The token is awaited only after its future resolves.
    use_sso_login_url: F,
    /// Device ID forwarded to the final token login.
    device_id: Option<&'a str>,
    /// Initial device display name forwarded to the final token login.
    initial_device_display_name: Option<&'a str>,
    /// URL the local redirect server listens on (default `http://127.0.0.1:0/`).
    server_url: Option<&'a str>,
    /// Body returned to whoever hits the redirect URL.
    server_response: Option<&'a str>,
    /// Identity provider ID passed when requesting the SSO login URL.
    identity_provider_id: Option<&'a str>,
    /// Whether the final token login asks for a refresh token.
    request_refresh_token: bool,
}

#[cfg(feature = "sso-login")]
impl<'a, F, Fut> SsoLoginBuilder<'a, F>
where
    F: FnOnce(String) -> Fut + Send,
    Fut: Future<Output = Result<()>> + Send,
{
    /// Creates a builder with all optional settings unset.
    pub(super) fn new(client: Client, use_sso_login_url: F) -> Self {
        Self {
            client,
            use_sso_login_url,
            device_id: None,
            initial_device_display_name: None,
            server_url: None,
            server_response: None,
            identity_provider_id: None,
            request_refresh_token: false,
        }
    }

    /// Sets the device ID used for the final token login.
    pub fn device_id(mut self, value: &'a str) -> Self {
        self.device_id = Some(value);
        self
    }

    /// Sets the initial device display name used for the final token login.
    pub fn initial_device_display_name(mut self, value: &'a str) -> Self {
        self.initial_device_display_name = Some(value);
        self
    }

    /// Sets the URL of the local redirect server.
    ///
    /// Defaults to `http://127.0.0.1:0/`. Port `0` means a random port in
    /// `20000..30000` is chosen. If the URL has no port, the scheme's default
    /// port is used.
    pub fn server_url(mut self, value: &'a str) -> Self {
        self.server_url = Some(value);
        self
    }

    /// Sets the body of the HTTP response served on the redirect URL.
    pub fn server_response(mut self, value: &'a str) -> Self {
        self.server_response = Some(value);
        self
    }

    /// Sets the identity provider ID used when requesting the SSO login URL.
    pub fn identity_provider_id(mut self, value: &'a str) -> Self {
        self.identity_provider_id = Some(value);
        self
    }

    /// Asks for a refresh token in the final token login.
    pub fn request_refresh_token(mut self) -> Self {
        self.request_refresh_token = true;
        self
    }

    /// Runs the SSO flow and logs in with the received login token.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - `server_url` does not parse, or has no host or usable port.
    /// - The local server cannot be bound.
    /// - Fetching the SSO URL fails, or the callback fails.
    /// - No `loginToken` is received.
    /// - The final token login fails.
    #[instrument(target = "matrix_sdk::client", name = "login", skip_all, fields(method = "sso"))]
    pub async fn send(self) -> Result<login::v3::Response> {
        use std::{
            collections::HashMap,
            io::{Error as IoError, ErrorKind as IoErrorKind},
            ops::Range,
            sync::{Arc, Mutex},
        };

        use rand::{thread_rng, Rng};
        use tokio::{net::TcpListener, sync::oneshot};
        use tokio_stream::wrappers::TcpListenerStream;
        use url::{Host, Url};
        use warp::Filter;

        /// Range of ports tried when the redirect URL requests a random port (port 0).
        const SSO_SERVER_BIND_RANGE: Range<u16> = 20000..30000;
        /// Number of random ports tried before giving up.
        const SSO_SERVER_BIND_TRIES: u8 = 10;

        let homeserver = self.client.homeserver().await;
        info!(%homeserver, "Logging in");

        // Stops the local server once the token has arrived. Dropping `signal_tx`
        // on an early return also ends it, because the receiver then errors.
        let (signal_tx, signal_rx) = oneshot::channel();
        // Carries the optional `loginToken` from the HTTP handler to this task.
        let (data_tx, data_rx) = oneshot::channel();
        // The handler may run many times, but the one-shot sender can be used only
        // once, so it is taken out of the mutex by the first request.
        let data_tx_mutex = Arc::new(Mutex::new(Some(data_tx)));

        let mut redirect_url = match self.server_url {
            Some(s) => Url::parse(s)?,
            None => {
                Url::parse("http://127.0.0.1:0/").expect("Couldn't parse good known localhost URL")
            }
        };

        let response = self
            .server_response
            .unwrap_or("The Single Sign-On login process is complete. You can close this page now.")
            .to_owned();

        let route = warp::get().and(warp::query::<HashMap<String, String>>()).map(
            move |p: HashMap<String, String>| {
                if let Some(data_tx) = data_tx_mutex.lock().unwrap().take() {
                    // The receiver is gone if this login future was dropped or
                    // returned early. Nobody is left to notify, so don't panic.
                    let _ = data_tx.send(p.get("loginToken").cloned());
                }
                http::Response::builder().body(response.clone())
            },
        );

        let listener = {
            // Match on `Host` instead of using `host_str()` so IPv6 literals reach
            // `bind` without the URL's surrounding brackets.
            let host = match redirect_url.host() {
                Some(Host::Domain(domain)) => domain.to_owned(),
                Some(Host::Ipv4(addr)) => addr.to_string(),
                Some(Host::Ipv6(addr)) => addr.to_string(),
                None => {
                    return Err(IoError::new(
                        IoErrorKind::InvalidInput,
                        "The redirect URL doesn't have a host",
                    )
                    .into());
                }
            };
            // A URL without an explicit port redirects to the scheme's default
            // port, so listen there instead of panicking.
            let port = redirect_url.port_or_known_default().ok_or_else(|| {
                IoError::new(IoErrorKind::InvalidInput, "The redirect URL doesn't include a port")
            })?;

            if port == 0 {
                // Port 0 asks for a random port. Try up to SSO_SERVER_BIND_TRIES of them.
                let mut attempts = 0u8;
                loop {
                    attempts += 1;
                    let port = thread_rng().gen_range(SSO_SERVER_BIND_RANGE);
                    match TcpListener::bind((host.as_str(), port)).await {
                        Ok(l) => {
                            redirect_url.set_port(Some(port)).map_err(|()| {
                                IoError::new(
                                    IoErrorKind::InvalidInput,
                                    "Could not set new port on redirect URL",
                                )
                            })?;
                            break l;
                        }
                        Err(_) if attempts < SSO_SERVER_BIND_TRIES => {}
                        Err(e) => return Err(e.into()),
                    }
                }
            } else {
                // Bind to (host, port). `TcpListener::bind` cannot resolve a full
                // URL string such as `http://127.0.0.1:8080/`.
                TcpListener::bind((host.as_str(), port)).await?
            }
        };

        let server = warp::serve(route).serve_incoming_with_graceful_shutdown(
            TcpListenerStream::new(listener),
            async {
                signal_rx.await.ok();
            },
        );

        tokio::spawn(server);

        let sso_url =
            self.client.get_sso_login_url(redirect_url.as_str(), self.identity_provider_id).await?;

        (self.use_sso_login_url)(sso_url).await?;

        let token = data_rx
            .await
            .map_err(|e| IoError::new(IoErrorKind::Other, format!("{e}")))?
            .ok_or_else(|| IoError::new(IoErrorKind::Other, "Could not get the loginToken"))?;

        // The server task may already have exited; ignoring the result is fine.
        let _ = signal_tx.send(());

        let login_builder = LoginBuilder {
            device_id: self.device_id,
            initial_device_display_name: self.initial_device_display_name,
            request_refresh_token: self.request_refresh_token,
            ..LoginBuilder::new_token(self.client, &token)
        };
        login_builder.send().await
    }
}