use std::{borrow::Cow, fmt, marker::PhantomData};
use bytes::BufMut;
use ruma_common::{
api::{EndpointError, OutgoingResponse, error::IntoHttpError},
serde::StringEnum,
};
use serde::{Deserialize, Deserializer, Serialize, de};
use serde_json::{from_slice as from_json_slice, value::RawValue as RawJsonValue};
use crate::{
PrivOwnedStr,
error::{Error as MatrixError, StandardErrorBody},
};
mod auth_data;
mod auth_params;
pub mod get_uiaa_fallback_page;
pub use self::{auth_data::*, auth_params::*};
/// The type of an authentication stage, identified by its string name.
///
/// Each variant maps to the string given in its `ruma_enum(rename = ...)` attribute.
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
#[derive(Clone, StringEnum)]
#[non_exhaustive]
pub enum AuthType {
/// Password-based authentication (`m.login.password`).
#[ruma_enum(rename = "m.login.password")]
Password,
/// ReCaptcha-based authentication (`m.login.recaptcha`).
#[ruma_enum(rename = "m.login.recaptcha")]
ReCaptcha,
/// Email identity-based authentication (`m.login.email.identity`).
#[ruma_enum(rename = "m.login.email.identity")]
EmailIdentity,
/// MSISDN (phone number) based authentication (`m.login.msisdn`).
#[ruma_enum(rename = "m.login.msisdn")]
Msisdn,
/// SSO-based authentication (`m.login.sso`).
#[ruma_enum(rename = "m.login.sso")]
Sso,
/// Dummy authentication (`m.login.dummy`).
#[ruma_enum(rename = "m.login.dummy")]
Dummy,
/// Registration token-based authentication (`m.login.registration_token`).
#[ruma_enum(rename = "m.login.registration_token")]
RegistrationToken,
/// Terms stage (`m.login.terms`).
///
/// Its parameters can be read with `UiaaInfo::params::<LoginTermsParams>()`.
#[ruma_enum(rename = "m.login.terms")]
Terms,
/// OAuth stage (`m.oauth`).
///
/// The name `org.matrix.cross_signing_reset` is also accepted as an alias for this variant.
/// Its parameters can be read with `UiaaInfo::params::<OAuthParams>()`.
#[ruma_enum(rename = "m.oauth", alias = "org.matrix.cross_signing_reset")]
OAuth,
// Holds a stage name that has no dedicated variant (e.g. a custom stage). Hidden from the
// generated documentation; NOTE(review): presumably not meant to be constructed directly,
// use `AuthType::from(..)` instead.
#[doc(hidden)]
_Custom(PrivOwnedStr),
}
/// Information about the state of a User-Interactive Authentication exchange.
///
/// This is the JSON body carried by `UiaaResponse::AuthResponse`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct UiaaInfo {
/// The authentication flows on offer; each flow is a list of stages.
///
/// This is the only field without a serde default, so it must be present when deserializing.
pub flows: Vec<AuthFlow>,
/// The stages reported as completed.
///
/// Defaults to an empty list when absent and is left out of the output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub completed: Vec<AuthType>,
/// Raw JSON with the per-stage parameters, keyed by stage name.
///
/// Use [`UiaaInfo::params()`] to deserialize the entry of a single stage. Left out of the
/// output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Box<RawJsonValue>>,
/// The session identifier, if any. Left out of the output when `None`.
// NOTE(review): presumably the value the client must echo back in follow-up requests of the
// same authentication attempt; confirm against the Matrix specification.
#[serde(skip_serializing_if = "Option::is_none")]
pub session: Option<String>,
/// An error body, if any.
///
/// Because of `flatten`, its fields live at the top level of the JSON object next to `flows`,
/// `session`, etc. rather than under an `auth_error` key.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub auth_error: Option<StandardErrorBody>,
}
impl UiaaInfo {
pub fn new(flows: Vec<AuthFlow>) -> Self {
Self { flows, completed: Vec::new(), params: None, session: None, auth_error: None }
}
pub fn params<'a, T: Deserialize<'a>>(
&'a self,
auth_type: &AuthType,
) -> Result<Option<T>, serde_json::Error> {
struct AuthTypeVisitor<'b, T> {
auth_type: &'b AuthType,
_phantom: PhantomData<T>,
}
impl<'de, T> de::Visitor<'de> for AuthTypeVisitor<'_, T>
where
T: Deserialize<'de>,
{
type Value = Option<T>;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("a key-value map")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut params = None;
while let Some(key) = map.next_key::<Cow<'de, str>>()? {
if AuthType::from(key) == *self.auth_type {
params = Some(map.next_value()?);
} else {
map.next_value::<de::IgnoredAny>()?;
}
}
Ok(params)
}
}
let Some(params) = &self.params else {
return Ok(None);
};
let mut deserializer = serde_json::Deserializer::from_str(params.get());
deserializer.deserialize_map(AuthTypeVisitor { auth_type, _phantom: PhantomData })
}
}
/// A single authentication flow, described by the list of stages it consists of.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct AuthFlow {
/// The stages that make up this flow.
///
/// Defaults to an empty list when absent and is left out of the output when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stages: Vec<AuthType>,
}
impl AuthFlow {
pub fn new(stages: Vec<AuthType>) -> Self {
Self { stages }
}
}
/// The error type of endpoints that may answer with a User-Interactive Authentication
/// challenge instead of a regular Matrix error.
#[derive(Clone, Debug)]
// NOTE(review): the lint is silenced on purpose; presumably these two variants are considered
// the complete set, so the enum is left exhaustive.
#[allow(clippy::exhaustive_enums)]
pub enum UiaaResponse {
/// User-Interactive Authentication is required; carries the current authentication state.
AuthResponse(UiaaInfo),
/// Any other Matrix error.
MatrixError(MatrixError),
}
impl fmt::Display for UiaaResponse {
    /// Writes a fixed message for an authentication challenge, or the wrapped error's own
    /// message otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self::MatrixError(inner) = self else {
            return f.write_str("User-Interactive Authentication required.");
        };

        write!(f, "{inner}")
    }
}
impl From<MatrixError> for UiaaResponse {
fn from(error: MatrixError) -> Self {
Self::MatrixError(error)
}
}
impl EndpointError for UiaaResponse {
fn from_http_response<T: AsRef<[u8]>>(response: http::Response<T>) -> Self {
if response.status() == http::StatusCode::UNAUTHORIZED
&& let Ok(uiaa_info) = from_json_slice(response.body().as_ref())
{
return Self::AuthResponse(uiaa_info);
}
Self::MatrixError(MatrixError::from_http_response(response))
}
}
// `UiaaResponse` derives `Debug` and implements `Display`, which is all `Error` requires; the
// trait's provided methods are used unchanged.
impl std::error::Error for UiaaResponse {}
impl OutgoingResponse for UiaaResponse {
    /// Converts this value into an HTTP response.
    ///
    /// An `AuthResponse` becomes a `401 Unauthorized` response with a JSON content type and the
    /// serialized [`UiaaInfo`] as body. A `MatrixError` is converted by its own
    /// `try_into_http_response` implementation.
    ///
    /// # Errors
    ///
    /// Returns an error if serializing the body or assembling the response fails.
    fn try_into_http_response<T: Default + BufMut>(
        self,
    ) -> Result<http::Response<T>, IntoHttpError> {
        let authentication_info = match self {
            Self::MatrixError(error) => return error.try_into_http_response(),
            Self::AuthResponse(info) => info,
        };

        let body: T = ruma_common::serde::json_to_buf(&authentication_info)?;

        http::Response::builder()
            .header(http::header::CONTENT_TYPE, ruma_common::http_headers::APPLICATION_JSON)
            .status(http::StatusCode::UNAUTHORIZED)
            .body(body)
            .map_err(Into::into)
    }
}
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use ruma_common::serde::JsonObject;
    use serde_json::{from_value as from_json_value, json};

    use super::{AuthType, LoginTermsParams, OAuthParams, UiaaInfo};

    /// `UiaaInfo::params()` finds the entry of the requested stage, for both known and custom
    /// stage names, and reports `None` for a stage without an entry.
    #[test]
    fn uiaa_info_params() {
        let payload = json!({
            "flows": [{
                "stages": ["m.login.terms", "m.login.email.identity", "local.custom.stage"],
            }],
            "params": {
                "local.custom.stage": {
                    "foo": "bar",
                },
                "m.login.terms": {
                    "policies": {
                        "privacy": {
                            "en-US": {
                                "name": "Privacy Policy",
                                "url": "http://matrix.local/en-US/privacy",
                            },
                            "fr-FR": {
                                "name": "Politique de confidentialité",
                                "url": "http://matrix.local/fr-FR/privacy",
                            },
                            "version": "1",
                        },
                    },
                }
            },
            "session": "abcdef",
        });

        let uiaa: UiaaInfo = from_json_value(payload).unwrap();

        // Listed in the flow, but without an entry in `params`.
        assert_matches!(uiaa.params::<JsonObject>(&AuthType::EmailIdentity), Ok(None));

        // A stage name without a dedicated variant is still found.
        let custom_stage = AuthType::from("local.custom.stage");
        assert_matches!(uiaa.params::<JsonObject>(&custom_stage), Ok(Some(_)));

        // A known stage with typed parameters.
        assert_matches!(uiaa.params::<LoginTermsParams>(&AuthType::Terms), Ok(Some(terms)));
        assert_eq!(terms.policies.len(), 1);

        let privacy = terms.policies.get("privacy").unwrap();
        assert_eq!(privacy.version, "1");
        assert_eq!(privacy.translations.len(), 2);

        let expected_translations = [
            ("en-US", "Privacy Policy", "http://matrix.local/en-US/privacy"),
            ("fr-FR", "Politique de confidentialité", "http://matrix.local/fr-FR/privacy"),
        ];
        for (language, name, url) in expected_translations {
            let translation = privacy.translations.get(language).unwrap();
            assert_eq!(translation.name, name);
            assert_eq!(translation.url, url);
        }
    }

    /// The OAuth parameters are found under both the stable stage name and its alias.
    #[test]
    fn uiaa_info_oauth_params() {
        let url = "http://auth.matrix.local/reset";

        let stable_json = json!({
            "flows": [{
                "stages": ["m.oauth"],
            }],
            "params": {
                "m.oauth": {
                    "url": url,
                }
            },
            "session": "abcdef",
        });
        let unstable_json = json!({
            "flows": [{
                "stages": ["org.matrix.cross_signing_reset"],
            }],
            "params": {
                "org.matrix.cross_signing_reset": {
                    "url": url,
                }
            },
            "session": "abcdef",
        });

        for payload in [stable_json, unstable_json] {
            let uiaa = from_json_value::<UiaaInfo>(payload).unwrap();
            assert_matches!(uiaa.params::<OAuthParams>(&AuthType::OAuth), Ok(Some(oauth)));
            assert_eq!(oauth.url, url);
        }
    }
}