use ruma::{
api::client::discovery::get_authorization_server_metadata::v1::AuthorizationServerMetadata,
serde::Raw,
};
use serde_json::json;
use url::Url;
use wiremock::{
Mock, MockBuilder, ResponseTemplate,
matchers::{method, path_regex},
};
use super::{MatrixMock, MatrixMockServer, MockEndpoint};
/// Helper to mock the OAuth 2.0 endpoints used by the client, on top of a
/// [`MatrixMockServer`].
pub struct OAuthMockServer<'a> {
    /// The mock server every mock created by this helper is registered on.
    server: &'a MatrixMockServer,
}

impl<'a> OAuthMockServer<'a> {
    /// Wraps `server` so that OAuth 2.0 endpoints can be mocked on it.
    pub(super) fn new(server: &'a MatrixMockServer) -> Self {
        OAuthMockServer { server }
    }

    /// Delegates to the wrapped server's `mock_endpoint`, so the returned
    /// endpoint borrows the underlying server rather than this helper.
    fn mock_endpoint<T>(&self, mock: MockBuilder, endpoint: T) -> MockEndpoint<'a, T> {
        let server: &'a MatrixMockServer = self.server;
        server.mock_endpoint(mock, endpoint)
    }

    /// Builds the default authorization server metadata, using this mock
    /// server's URI as the issuer, and deserializes it.
    ///
    /// # Panics
    ///
    /// Panics if the generated metadata does not deserialize into
    /// [`AuthorizationServerMetadata`].
    pub fn server_metadata(&self) -> AuthorizationServerMetadata {
        let raw_metadata = MockServerMetadataBuilder::new(&self.server.uri()).build();
        raw_metadata
            .deserialize()
            .expect("mock OAuth 2.0 server metadata should deserialize successfully")
    }
}
impl OAuthMockServer<'_> {
    /// Creates a prebuilt mock for the Matrix endpoint that returns the
    /// authorization server's metadata (unstable MSC2965 `auth_metadata`).
    ///
    /// Unlike the other mocks of this helper, this endpoint lives under the
    /// Matrix client API prefix rather than under `/oauth2/`.
    pub fn mock_server_metadata(&self) -> MockEndpoint<'_, ServerMetadataEndpoint> {
        // The dots are escaped so they match literally; an unescaped `.` would
        // match any character and make the mock accept unrelated paths.
        let mock = Mock::given(method("GET"))
            .and(path_regex(r"^/_matrix/client/unstable/org\.matrix\.msc2965/auth_metadata"));
        self.mock_endpoint(mock, ServerMetadataEndpoint)
    }

    /// Creates a prebuilt mock for `POST /oauth2/registration`, the dynamic
    /// client registration endpoint advertised in the mocked metadata.
    pub fn mock_registration(&self) -> MockEndpoint<'_, RegistrationEndpoint> {
        let mock = Mock::given(method("POST")).and(path_regex(r"^/oauth2/registration"));
        self.mock_endpoint(mock, RegistrationEndpoint)
    }

    /// Creates a prebuilt mock for `POST /oauth2/device`, the device
    /// authorization endpoint advertised in the mocked metadata.
    pub fn mock_device_authorization(&self) -> MockEndpoint<'_, DeviceAuthorizationEndpoint> {
        let mock = Mock::given(method("POST")).and(path_regex(r"^/oauth2/device"));
        self.mock_endpoint(mock, DeviceAuthorizationEndpoint)
    }

    /// Creates a prebuilt mock for `POST /oauth2/token`, the token endpoint
    /// advertised in the mocked metadata.
    pub fn mock_token(&self) -> MockEndpoint<'_, TokenEndpoint> {
        let mock = Mock::given(method("POST")).and(path_regex(r"^/oauth2/token"));
        self.mock_endpoint(mock, TokenEndpoint)
    }

    /// Creates a prebuilt mock for `POST /oauth2/revoke`, the token revocation
    /// endpoint advertised in the mocked metadata.
    pub fn mock_revocation(&self) -> MockEndpoint<'_, RevocationEndpoint> {
        let mock = Mock::given(method("POST")).and(path_regex(r"^/oauth2/revoke"));
        self.mock_endpoint(mock, RevocationEndpoint)
    }
}
pub struct ServerMetadataEndpoint;
impl<'a> MockEndpoint<'a, ServerMetadataEndpoint> {
pub fn ok(self) -> MatrixMock<'a> {
let metadata = MockServerMetadataBuilder::new(&self.server.uri()).build();
self.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
}
pub fn ok_https(self) -> MatrixMock<'a> {
let issuer = self.server.uri().replace("http://", "https://");
let metadata = MockServerMetadataBuilder::new(&issuer).build();
self.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
}
pub fn ok_without_device_authorization(self) -> MatrixMock<'a> {
let metadata = MockServerMetadataBuilder::new(&self.server.uri())
.without_device_authorization()
.build();
self.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
}
pub fn ok_without_registration(self) -> MatrixMock<'a> {
let metadata =
MockServerMetadataBuilder::new(&self.server.uri()).without_registration().build();
self.respond_with(ResponseTemplate::new(200).set_body_json(metadata))
}
}
/// Builder for the OAuth 2.0 authorization server metadata returned by the
/// mocked server-metadata endpoint.
#[derive(Debug, Clone)]
pub struct MockServerMetadataBuilder {
    /// The issuer URL; every endpoint URL in the metadata is resolved against it.
    issuer: Url,
    /// Whether `device_authorization_endpoint` is included in the built metadata.
    with_device_authorization: bool,
    /// Whether `registration_endpoint` is included in the built metadata.
    with_registration: bool,
}
impl MockServerMetadataBuilder {
    /// Creates a builder for the metadata of an authorization server whose
    /// issuer is `issuer`.
    ///
    /// Both the device authorization endpoint and the dynamic client
    /// registration endpoint are included by default.
    ///
    /// # Panics
    ///
    /// Panics if `issuer` cannot be parsed as a URL.
    pub fn new(issuer: &str) -> Self {
        let issuer = Url::parse(issuer).expect("We should be able to parse the issuer");
        Self { issuer, with_device_authorization: true, with_registration: true }
    }

    /// Excludes `device_authorization_endpoint` from the built metadata.
    fn without_device_authorization(mut self) -> Self {
        self.with_device_authorization = false;
        self
    }

    /// Excludes `registration_endpoint` from the built metadata.
    fn without_registration(mut self) -> Self {
        self.with_registration = false;
        self
    }

    /// Resolves `path` relative to the issuer URL.
    ///
    /// Resolution follows [`Url::join`]: if the issuer's path does not end
    /// with `/`, its last segment is replaced rather than extended.
    ///
    /// # Panics
    ///
    /// Panics if the issuer cannot be used as a base URL.
    fn endpoint_url(&self, path: &str) -> Url {
        self.issuer.join(path).unwrap_or_else(|err| {
            panic!("endpoint `{path}` should resolve against issuer `{}`: {err}", self.issuer)
        })
    }

    fn authorization_endpoint(&self) -> Url {
        self.endpoint_url("oauth2/authorize")
    }

    fn token_endpoint(&self) -> Url {
        self.endpoint_url("oauth2/token")
    }

    fn jwks_uri(&self) -> Url {
        self.endpoint_url("oauth2/keys.json")
    }

    fn registration_endpoint(&self) -> Url {
        self.endpoint_url("oauth2/registration")
    }

    fn account_management_uri(&self) -> Url {
        self.endpoint_url("account")
    }

    fn device_authorization_endpoint(&self) -> Url {
        self.endpoint_url("oauth2/device")
    }

    fn revocation_endpoint(&self) -> Url {
        self.endpoint_url("oauth2/revoke")
    }

    /// Builds the metadata as raw JSON.
    ///
    /// The authorization, token and revocation endpoints and the account
    /// management URI are always present. The device authorization and
    /// registration endpoints are added unless they were disabled on this
    /// builder.
    ///
    /// # Panics
    ///
    /// Panics if an endpoint URL cannot be resolved against the issuer.
    pub fn build(&self) -> Raw<AuthorizationServerMetadata> {
        let mut metadata = json!({
            "issuer": self.issuer,
            "authorization_endpoint": self.authorization_endpoint(),
            "token_endpoint": self.token_endpoint(),
            "response_types_supported": ["code"],
            "response_modes_supported": ["query", "fragment"],
            "grant_types_supported": [
                "authorization_code",
                "refresh_token",
                "urn:ietf:params:oauth:grant-type:device_code",
            ],
            "revocation_endpoint": self.revocation_endpoint(),
            "code_challenge_methods_supported": ["S256"],
            "account_management_uri": self.account_management_uri(),
            "account_management_actions_supported": [
                "org.matrix.profile",
                "org.matrix.sessions_list",
                "org.matrix.session_view",
                "org.matrix.session_end",
                "org.matrix.deactivateaccount",
                "org.matrix.cross_signing_reset",
            ],
            "prompt_values_supported": ["create"],
        });

        let fields = metadata
            .as_object_mut()
            .expect("the metadata built by `json!` above is a JSON object");

        // Optional endpoints go through `json!` as well, so they are converted
        // to JSON the same way as the `Url` fields of the literal above.
        if self.with_device_authorization {
            fields.insert(
                "device_authorization_endpoint".to_owned(),
                json!(self.device_authorization_endpoint()),
            );
        }
        if self.with_registration {
            fields.insert(
                "registration_endpoint".to_owned(),
                json!(self.registration_endpoint()),
            );
        }

        serde_json::from_value(metadata)
            .expect("a JSON object should always convert into raw server metadata")
    }
}
/// A prebuilt mock for the OAuth 2.0 dynamic client registration endpoint.
pub struct RegistrationEndpoint;

impl<'a> MockEndpoint<'a, RegistrationEndpoint> {
    /// Responds with `200 OK` and a registration response carrying a fixed
    /// `client_id` and `client_id_issued_at`.
    pub fn ok(self) -> MatrixMock<'a> {
        let registration = json!({
            "client_id": "test_client_id",
            "client_id_issued_at": 1716375696,
        });
        let response = ResponseTemplate::new(200).set_body_json(registration);
        self.respond_with(response)
    }
}
/// A prebuilt mock for the OAuth 2.0 device authorization endpoint.
pub struct DeviceAuthorizationEndpoint;

impl<'a> MockEndpoint<'a, DeviceAuthorizationEndpoint> {
    /// Responds with `200 OK` and a device authorization response.
    ///
    /// The verification URIs point to `/link` on the mock server itself; the
    /// "complete" variant additionally carries the user code as a query.
    pub fn ok(self) -> MatrixMock<'a> {
        let server_url = Url::parse(&self.server.uri())
            .expect("We should be able to parse the wiremock server URI");

        let verification_uri = server_url.join("link").unwrap();
        // Same URL as `verification_uri`, with the user code appended.
        let verification_uri_complete = {
            let mut with_code = verification_uri.clone();
            with_code.set_query(Some("code=N32YVC"));
            with_code
        };

        let body = json!({
            "device_code": "N8NAYD9fOhMulpm37mSthx0xSw2p7vdR",
            "expires_in": 1200,
            "interval": 5,
            "user_code": "N32YVC",
            "verification_uri": verification_uri,
            "verification_uri_complete": verification_uri_complete,
        });
        self.respond_with(ResponseTemplate::new(200).set_body_json(body))
    }
}
pub struct TokenEndpoint;
impl<'a> MockEndpoint<'a, TokenEndpoint> {
pub fn ok(self) -> MatrixMock<'a> {
self.ok_with_tokens("1234", "ZYXWV")
}
pub fn ok_with_tokens(self, access_token: &str, refresh_token: &str) -> MatrixMock<'a> {
self.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"access_token": access_token,
"expires_in": 300,
"refresh_token": refresh_token,
"token_type": "Bearer"
})))
}
pub fn access_denied(self) -> MatrixMock<'a> {
self.respond_with(ResponseTemplate::new(400).set_body_json(json!({
"error": "access_denied",
})))
}
pub fn expired_token(self) -> MatrixMock<'a> {
self.respond_with(ResponseTemplate::new(400).set_body_json(json!({
"error": "expired_token",
})))
}
pub fn invalid_grant(self) -> MatrixMock<'a> {
self.respond_with(ResponseTemplate::new(400).set_body_json(json!({
"error": "invalid_grant",
})))
}
}
/// A prebuilt mock for the OAuth 2.0 token revocation endpoint.
pub struct RevocationEndpoint;

impl<'a> MockEndpoint<'a, RevocationEndpoint> {
    /// Responds with `200 OK` and an empty JSON object as the body.
    pub fn ok(self) -> MatrixMock<'a> {
        let empty_object = json!({});
        let response = ResponseTemplate::new(200).set_body_json(empty_object);
        self.respond_with(response)
    }
}