use std::collections::{BTreeSet, HashMap};
pub use language_tags;
use language_tags::LanguageTag;
use matrix_sdk_base::deserialized_responses::PrivOwnedStr;
use oauth2::{AsyncHttpClient, ClientId, HttpClientError, RequestTokenError};
use ruma::{
SecondsSinceUnixEpoch,
api::client::discovery::get_authorization_server_metadata::v1::{GrantType, ResponseType},
serde::{Raw, StringEnum},
};
use serde::{Deserialize, Serialize, ser::SerializeMap};
use url::Url;
use super::{
OAuthHttpClient,
error::OAuthClientRegistrationError,
http_client::{check_http_response_json_content_type, check_http_response_status_code},
};
/// Registers a client with an OAuth 2.0 authorization server.
///
/// The already-serialized `client_metadata` is sent as a JSON `POST` body to
/// `registration_endpoint`, and the JSON response body is parsed into a
/// [`ClientRegistrationResponse`].
///
/// # Arguments
///
/// * `http_client` - The client used to send the HTTP request.
///
/// * `registration_endpoint` - The URL of the server's registration endpoint.
///
/// * `client_metadata` - The metadata to register, already in raw JSON form.
///
/// # Errors
///
/// Returns an error if the metadata cannot be converted to JSON, if the HTTP
/// request cannot be built or sent, if the response fails the status-code or
/// JSON content-type checks, or if the response body cannot be deserialized.
// The endpoint is recorded explicitly: a bare `fields(registration_endpoint)`
// only declares an empty field and never records a value.
#[tracing::instrument(skip_all, fields(registration_endpoint = %registration_endpoint))]
pub(super) async fn register_client(
    http_client: &OAuthHttpClient,
    registration_endpoint: &Url,
    client_metadata: &Raw<ClientMetadata>,
) -> Result<ClientRegistrationResponse, OAuthClientRegistrationError> {
    tracing::debug!("Registering client...");

    let body =
        serde_json::to_vec(client_metadata).map_err(OAuthClientRegistrationError::IntoJson)?;

    let request = http::Request::post(registration_endpoint.as_str())
        .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.to_string())
        .body(body)
        .map_err(|err| RequestTokenError::Request(HttpClientError::Http(err)))?;

    let response = http_client.call(request).await.map_err(RequestTokenError::Request)?;

    // Validate the response before trying to interpret its body.
    check_http_response_status_code(&response)?;
    check_http_response_json_content_type(&response)?;

    let response = serde_json::from_slice(&response.into_body())
        .map_err(OAuthClientRegistrationError::FromJson)?;

    Ok(response)
}
/// The part of a client registration response that this module reads.
///
/// `register_client` deserializes it from the JSON body returned by the
/// registration endpoint. The derived `Deserialize` impl does not deny unknown
/// fields, so any other fields in the response are ignored.
#[derive(Debug, Clone, Deserialize)]
pub struct ClientRegistrationResponse {
    /// The client identifier returned in the registration response.
    pub client_id: ClientId,
    /// When the client identifier was issued, if the response includes it.
    pub client_id_issued_at: Option<SecondsSinceUnixEpoch>,
}
/// Metadata describing this client, serialized for client registration.
///
/// This type is not serialized field by field. `#[serde(into = ...)]` first
/// converts a clone of it into `ClientMetadataSerializeHelper`, which is why
/// `Clone` is derived. That conversion:
///
/// - derives the `grant_types`, `redirect_uris` and `response_types` keys from
///   [`grant_types`](Self::grant_types);
/// - writes each [`Localized`] value under its plain field name, plus one
///   `<field>#<language tag>` key per translation.
#[derive(Debug, Clone, Serialize)]
#[serde(into = "ClientMetadataSerializeHelper")]
pub struct ClientMetadata {
    /// Serialized as `application_type`.
    pub application_type: ApplicationType,
    /// The grant types the client declares.
    ///
    /// The refresh token grant is always added during serialization.
    pub grant_types: Vec<OAuthGrantType>,
    /// Serialized as `client_uri` (plus localized variants); always present.
    pub client_uri: Localized<Url>,
    /// Serialized as `client_name` (plus localized variants); omitted when `None`.
    pub client_name: Option<Localized<String>>,
    /// Serialized as `logo_uri` (plus localized variants); omitted when `None`.
    pub logo_uri: Option<Localized<Url>>,
    /// Serialized as `policy_uri` (plus localized variants); omitted when `None`.
    pub policy_uri: Option<Localized<Url>>,
    /// Serialized as `tos_uri` (plus localized variants); omitted when `None`.
    pub tos_uri: Option<Localized<Url>>,
}
impl ClientMetadata {
    /// Creates client metadata with only the required fields set.
    ///
    /// The optional localized fields (`client_name`, `logo_uri`, `policy_uri`
    /// and `tos_uri`) start out as `None`. Callers can fill them in afterwards
    /// through the public fields.
    pub fn new(
        application_type: ApplicationType,
        grant_types: Vec<OAuthGrantType>,
        client_uri: Localized<Url>,
    ) -> Self {
        // None of the optional fields are known at construction time.
        let (client_name, logo_uri, policy_uri, tos_uri) = (None, None, None, None);

        Self {
            client_name,
            logo_uri,
            policy_uri,
            tos_uri,
            application_type,
            grant_types,
            client_uri,
        }
    }
}
/// A grant type the client declares in its metadata.
///
/// During serialization each variant is expanded into an entry of the
/// `grant_types` key. The authorization code grant also contributes to
/// `redirect_uris` and `response_types`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum OAuthGrantType {
    /// The authorization code grant, serialized as `authorization_code`.
    AuthorizationCode {
        /// The redirect URIs to register, serialized under `redirect_uris`.
        redirect_uris: Vec<Url>,
    },
    /// The device code grant, serialized as
    /// `urn:ietf:params:oauth:grant-type:device_code`.
    DeviceCode,
}
/// The `application_type` declared in the client metadata.
///
/// The `StringEnum` derive maps known variants to lowercase strings (`"web"`,
/// `"native"`).
#[derive(Clone, StringEnum)]
#[ruma_enum(rename_all = "lowercase")]
#[non_exhaustive]
pub enum ApplicationType {
    /// Serialized as `"web"`.
    Web,
    /// Serialized as `"native"`.
    Native,
    // Catch-all for other string values, following ruma's `StringEnum`
    // convention; not part of the public API.
    #[doc(hidden)]
    _Custom(PrivOwnedStr),
}
/// A value with an optional set of per-language variants.
///
/// When used in client metadata, the non-localized value is written under the
/// plain field name and each localized value under `<field>#<language tag>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Localized<T> {
    /// The value returned when no language is requested.
    non_localized: T,
    /// Language-specific variants, keyed by language tag.
    localized: HashMap<LanguageTag, T>,
}
impl<T> Localized<T> {
pub fn new(non_localized: T, localized: impl IntoIterator<Item = (LanguageTag, T)>) -> Self {
Self { non_localized, localized: localized.into_iter().collect() }
}
pub fn non_localized(&self) -> &T {
&self.non_localized
}
pub fn get(&self, language: Option<&LanguageTag>) -> Option<&T> {
match language {
Some(lang) => self.localized.get(lang),
None => Some(&self.non_localized),
}
}
}
impl<T> From<(T, HashMap<LanguageTag, T>)> for Localized<T> {
fn from(t: (T, HashMap<LanguageTag, T>)) -> Self {
Localized { non_localized: t.0, localized: t.1 }
}
}
/// Wire representation of [`ClientMetadata`], produced by its
/// `#[serde(into = ...)]` attribute.
///
/// The derived `Serialize` writes fields in declaration order, so reordering
/// fields here changes the key order of the serialized JSON.
#[derive(Serialize)]
struct ClientMetadataSerializeHelper {
    /// Omitted when empty, e.g. when no authorization code grant is declared.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    redirect_uris: Vec<Url>,
    /// Always `"none"` in the conversion from `ClientMetadata`.
    token_endpoint_auth_method: &'static str,
    /// A `BTreeSet`, so entries are deduplicated and emitted in `GrantType`'s
    /// `Ord` order.
    grant_types: BTreeSet<GrantType>,
    /// Omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    response_types: Vec<ResponseType>,
    application_type: ApplicationType,
    /// Flattened, so the localized keys appear at the top level of the object.
    #[serde(flatten)]
    localized: ClientMetadataLocalizedFields,
}
impl From<ClientMetadata> for ClientMetadataSerializeHelper {
    /// Expands the high-level [`OAuthGrantType`] list into the separate
    /// `grant_types`, `redirect_uris` and `response_types` fields.
    ///
    /// The refresh token grant is always included. The redirect URIs of every
    /// `AuthorizationCode` entry are concatenated in order, so listing that
    /// grant more than once does not drop previously listed URIs.
    fn from(value: ClientMetadata) -> Self {
        let ClientMetadata {
            application_type,
            grant_types: oauth_grant_types,
            client_uri,
            client_name,
            logo_uri,
            policy_uri,
            tos_uri,
        } = value;

        let mut redirect_uris = Vec::new();
        let mut response_types = Vec::new();
        // The refresh token grant is declared regardless of the other grants.
        let mut grant_types = BTreeSet::from([GrantType::RefreshToken]);

        for oauth_grant_type in oauth_grant_types {
            match oauth_grant_type {
                OAuthGrantType::AuthorizationCode { redirect_uris: uris } => {
                    // Accumulate rather than overwrite, so repeated
                    // `AuthorizationCode` entries keep all of their URIs.
                    redirect_uris.extend(uris);
                    // The `code` response type only needs to be listed once.
                    if response_types.is_empty() {
                        response_types.push(ResponseType::Code);
                    }
                    grant_types.insert(GrantType::AuthorizationCode);
                }
                OAuthGrantType::DeviceCode => {
                    grant_types.insert(GrantType::DeviceCode);
                }
            }
        }

        ClientMetadataSerializeHelper {
            redirect_uris,
            // No client authentication method is registered for the token endpoint.
            token_endpoint_auth_method: "none",
            grant_types,
            response_types,
            application_type,
            localized: ClientMetadataLocalizedFields {
                client_uri,
                client_name,
                logo_uri,
                policy_uri,
                tos_uri,
            },
        }
    }
}
/// The localizable fields of [`ClientMetadata`].
///
/// They are kept in a separate struct because they need a hand-written
/// `Serialize` impl: each [`Localized`] value expands into several map keys.
struct ClientMetadataLocalizedFields {
    client_uri: Localized<Url>,
    client_name: Option<Localized<String>>,
    logo_uri: Option<Localized<Url>>,
    policy_uri: Option<Localized<Url>>,
    tos_uri: Option<Localized<Url>>,
}
impl Serialize for ClientMetadataLocalizedFields {
    /// Serializes the present fields as map entries; `None` fields are omitted.
    ///
    /// Each [`Localized`] value becomes a plain `<field>` key followed by one
    /// `<field>#<language tag>` key per translation. Translations are emitted
    /// in sorted key order so the output is deterministic. Iterating the
    /// underlying `HashMap` directly would give an arbitrary order that can
    /// differ between runs.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Writes `value` into `map` under `field_name`, followed by its
        // translations sorted by their `<field>#<language tag>` key.
        fn serialize_localized_into_map<M: SerializeMap, T: Serialize>(
            map: &mut M,
            field_name: &str,
            value: &Localized<T>,
        ) -> Result<(), M::Error> {
            map.serialize_entry(field_name, &value.non_localized)?;

            let mut translations: Vec<(String, &T)> = value
                .localized
                .iter()
                .map(|(lang, localized)| (format!("{field_name}#{lang}"), localized))
                .collect();
            translations.sort_unstable_by(|a, b| a.0.cmp(&b.0));

            for (key, localized) in translations {
                map.serialize_entry(&key, localized)?;
            }

            Ok(())
        }

        let mut map = serializer.serialize_map(None)?;

        serialize_localized_into_map(&mut map, "client_uri", &self.client_uri)?;

        if let Some(client_name) = &self.client_name {
            serialize_localized_into_map(&mut map, "client_name", client_name)?;
        }

        if let Some(logo_uri) = &self.logo_uri {
            serialize_localized_into_map(&mut map, "logo_uri", logo_uri)?;
        }

        if let Some(policy_uri) = &self.policy_uri {
            serialize_localized_into_map(&mut map, "policy_uri", policy_uri)?;
        }

        if let Some(tos_uri) = &self.tos_uri {
            serialize_localized_into_map(&mut map, "tos_uri", tos_uri)?;
        }

        map.end()
    }
}
#[cfg(test)]
mod tests {
    use language_tags::LanguageTag;
    use serde_json::json;
    use url::Url;

    use super::{ApplicationType, ClientMetadata, Localized, OAuthGrantType};

    /// Parses a URL literal used by these tests.
    fn url(input: &str) -> Url {
        Url::parse(input).unwrap()
    }

    /// Parses a language tag literal used by these tests.
    fn lang(input: &str) -> LanguageTag {
        LanguageTag::parse(input).unwrap()
    }

    /// With only the required fields set, no optional localized keys appear.
    #[test]
    fn test_serialize_minimal_client_metadata() {
        let grant =
            OAuthGrantType::AuthorizationCode { redirect_uris: vec![url("http://127.0.0.1/")] };
        let client_uri = Localized::new(url("https://github.com/matrix-org/matrix-rust-sdk"), []);
        let metadata = ClientMetadata::new(ApplicationType::Native, vec![grant], client_uri);

        let expected = json!({
            "application_type": "native",
            "grant_types": ["authorization_code", "refresh_token"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "none",
            "redirect_uris": ["http://127.0.0.1/"],
            "client_uri": "https://github.com/matrix-org/matrix-rust-sdk",
        });
        assert_eq!(serde_json::to_value(metadata).unwrap(), expected);
    }

    /// Every field set, with several translations per localized field.
    #[test]
    fn test_serialize_full_client_metadata() {
        let fr = lang("fr");
        let mas = lang("mas");

        let metadata = ClientMetadata {
            application_type: ApplicationType::Web,
            grant_types: vec![
                OAuthGrantType::AuthorizationCode {
                    redirect_uris: vec![url("http://127.0.0.1/"), url("http://[::1]/")],
                },
                OAuthGrantType::DeviceCode,
            ],
            client_uri: Localized::new(
                url("https://example.org/matrix-client"),
                [
                    (fr.clone(), url("https://example.org/fr/matrix-client")),
                    (mas.clone(), url("https://example.org/mas/matrix-client")),
                ],
            ),
            client_name: Some(Localized::new(
                "My Matrix client".to_owned(),
                [(fr.clone(), "Mon client Matrix".to_owned())],
            )),
            logo_uri: Some(Localized::new(url("https://example.org/logo.svg"), [])),
            policy_uri: Some(Localized::new(
                url("https://example.org/policy"),
                [
                    (fr.clone(), url("https://example.org/fr/policy")),
                    (mas.clone(), url("https://example.org/mas/policy")),
                ],
            )),
            // Last use of the tags: move them instead of cloning.
            tos_uri: Some(Localized::new(
                url("https://example.org/tos"),
                [
                    (fr, url("https://example.org/fr/tos")),
                    (mas, url("https://example.org/mas/tos")),
                ],
            )),
        };

        let expected = json!({
            "application_type": "web",
            "grant_types": [
                "authorization_code",
                "refresh_token",
                "urn:ietf:params:oauth:grant-type:device_code",
            ],
            "response_types": ["code"],
            "token_endpoint_auth_method": "none",
            "redirect_uris": ["http://127.0.0.1/", "http://[::1]/"],
            "client_uri": "https://example.org/matrix-client",
            "client_uri#fr": "https://example.org/fr/matrix-client",
            "client_uri#mas": "https://example.org/mas/matrix-client",
            "client_name": "My Matrix client",
            "client_name#fr": "Mon client Matrix",
            "logo_uri": "https://example.org/logo.svg",
            "policy_uri": "https://example.org/policy",
            "policy_uri#fr": "https://example.org/fr/policy",
            "policy_uri#mas": "https://example.org/mas/policy",
            "tos_uri": "https://example.org/tos",
            "tos_uri#fr": "https://example.org/fr/tos",
            "tos_uri#mas": "https://example.org/mas/tos",
        });
        assert_eq!(serde_json::to_value(metadata).unwrap(), expected);
    }
}