use anyhow::Context as _;
use assert_matches::assert_matches;
use matrix_sdk_base::store::RoomLoadSettings;
use matrix_sdk_test::async_test;
use oauth2::{ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope};
use ruma::{
DeviceId, ServerName, api::client::discovery::get_authorization_server_metadata::v1::Prompt,
device_id, owned_device_id, user_id,
};
use tokio::sync::broadcast::error::TryRecvError;
use url::Url;
use super::{
AuthorizationCode, AuthorizationError, AuthorizationResponse, OAuth, OAuthAuthorizationData,
OAuthError, RedirectUriQueryParseError, UrlOrQuery,
};
use crate::{
Client, Error, SessionChange,
authentication::oauth::{
AuthorizationValidationData, ClientRegistrationData, OAuthAuthorizationCodeError,
error::{AuthorizationCodeErrorResponseType, OAuthClientRegistrationError},
},
test_utils::{
client::{
MockClientBuilder, mock_prev_session_tokens_with_refresh,
mock_session_tokens_with_refresh,
oauth::{mock_client_id, mock_client_metadata, mock_redirect_uri, mock_session},
},
mocks::MatrixMockServer,
},
};
/// Loopback redirect URI used by the tests that build authorization requests or validation
/// data by hand.
const REDIRECT_URI_STRING: &str = "http://127.0.0.1:6778/oauth/callback";
/// Sets up a mock homeserver and an unlogged client for the high-level login tests.
///
/// The mock server answers `whoami`, serves the OAuth 2.0 server metadata and the client
/// registration endpoint (each expected to be requested exactly once), and accepts token
/// requests.
///
/// Returns the client's [`OAuth`] API, the mock server (keep it alive while the mocks are
/// needed), the mock redirect URI and registration data built from the mock client metadata.
async fn mock_environment() -> anyhow::Result<(OAuth, MatrixMockServer, Url, ClientRegistrationData)>
{
    let server = MatrixMockServer::new().await;
    server.mock_who_am_i().ok().named("whoami").mount().await;

    let oauth_server = server.oauth();
    oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
    oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
    oauth_server.mock_token().ok().mount().await;

    let client = server.client_builder().unlogged().build().await;
    let oauth = client.oauth();
    let registration_data: ClientRegistrationData = mock_client_metadata().into();

    Ok((oauth, server, mock_redirect_uri(), registration_data))
}
/// Asserts that `authorization_data` contains a well-formed authorization URL.
///
/// The URL must point to `/oauth2/authorize` under `server_uri` and carry each expected query
/// parameter exactly once: `response_type`, `client_id`, `redirect_uri`, `scope`, `state`,
/// `code_challenge` and `code_challenge_method`, plus `prompt` and `login_hint` when expected.
/// Any other parameter is rejected. The redirect URI and the PKCE challenge are checked against
/// the validation data stored in `oauth` for the request's state.
///
/// * `device_id` - when set, the requested device scope must be for this device; otherwise any
///   device scope is accepted.
/// * `expected_prompt`, `expected_login_hint` - expected values of the `prompt` and
///   `login_hint` parameters; `None` means that the parameter must be absent.
/// * `additional_scopes` - scopes that must be requested on top of the API and device scopes;
///   `None` means that only those two scopes must be requested.
///
/// # Panics
///
/// Panics if any check fails, including when `oauth` has no validation data for the request.
async fn check_authorization_url(
    authorization_data: &OAuthAuthorizationData,
    oauth: &OAuth,
    server_uri: &str,
    device_id: Option<&DeviceId>,
    expected_prompt: Option<&str>,
    expected_login_hint: Option<&str>,
    additional_scopes: Option<Vec<Scope>>,
) {
    tracing::debug!("authorization data URL = {}", authorization_data.url);

    let data = oauth.data().unwrap();
    let authorization_data_guard = data.authorization_data.lock().await;
    let validation_data =
        authorization_data_guard.get(&authorization_data.state).expect("missing validation data");

    // 7 parameters are always present, `prompt` and `login_hint` only when expected.
    let mut num_expected =
        7 + i8::from(expected_prompt.is_some()) + i8::from(expected_login_hint.is_some());
    // A repeated parameter would decrement `num_expected` twice and could hide a missing one,
    // so every parameter must appear at most once.
    let mut seen_keys = std::collections::HashSet::new();
    let mut code_challenge = None;
    let mut prompt = None;
    let mut login_hint = None;

    for (key, val) in authorization_data.url.query_pairs() {
        assert!(seen_keys.insert(key.clone()), "duplicate query parameter: {key}");

        match &*key {
            "response_type" => {
                assert_eq!(val, "code");
                num_expected -= 1;
            }
            "client_id" => {
                assert_eq!(val, "test_client_id");
                num_expected -= 1;
            }
            "redirect_uri" => {
                assert_eq!(val, validation_data.redirect_uri.as_str());
                num_expected -= 1;
            }
            "scope" => {
                let actual_scopes: Vec<String> = val.split(' ').map(String::from).collect();

                assert!(actual_scopes.len() >= 2, "Expected at least two scopes");

                assert!(
                    actual_scopes
                        .contains(&"urn:matrix:org.matrix.msc2967.client:api:*".to_owned()),
                    "Expected Matrix API scope not found in scopes"
                );

                if let Some(device_id) = device_id {
                    let device_id_scope =
                        format!("urn:matrix:org.matrix.msc2967.client:device:{device_id}");
                    assert!(
                        actual_scopes.contains(&device_id_scope),
                        "Expected device ID scope not found in scopes"
                    )
                } else {
                    assert!(
                        actual_scopes
                            .iter()
                            .any(|s| s.starts_with("urn:matrix:org.matrix.msc2967.client:device:")),
                        "Expected device ID scope not found in scopes"
                    );
                }

                if let Some(additional_scopes) = &additional_scopes {
                    let expected_len = 2 + additional_scopes.len();
                    assert_eq!(actual_scopes.len(), expected_len, "Expected {expected_len} scopes",);
                    for scope in additional_scopes {
                        assert!(
                            actual_scopes.contains(scope),
                            "Expected additional scope not found in scopes: {scope:?}",
                        );
                    }
                } else {
                    // Without additional scopes, only the API and device scopes are requested,
                    // mirroring the `2 + additional_scopes.len()` count above.
                    assert_eq!(actual_scopes.len(), 2, "Expected exactly two scopes");
                }

                num_expected -= 1;
            }
            "state" => {
                num_expected -= 1;
                assert_eq!(val, authorization_data.state.secret().as_str());
            }
            "code_challenge" => {
                code_challenge = Some(val);
                num_expected -= 1;
            }
            "code_challenge_method" => {
                assert_eq!(val, "S256");
                num_expected -= 1;
            }
            "prompt" => {
                prompt = Some(val);
                num_expected -= 1;
            }
            "login_hint" => {
                login_hint = Some(val);
                num_expected -= 1;
            }
            _ => panic!("unexpected query parameter: {key}={val}"),
        }
    }

    assert_eq!(num_expected, 0);

    // The challenge must be derived from the verifier that was stored for this request.
    let code_challenge = code_challenge.expect("missing code_challenge");
    assert_eq!(
        code_challenge,
        PkceCodeChallenge::from_code_verifier_sha256(&validation_data.pkce_verifier).as_str()
    );

    assert_eq!(prompt.as_deref(), expected_prompt);
    assert_eq!(login_hint.as_deref(), expected_login_hint);

    assert!(authorization_data.url.as_str().starts_with(server_uri));
    assert_eq!(authorization_data.url.path(), "/oauth2/authorize");
}
#[async_test]
async fn test_high_level_login() -> anyhow::Result<()> {
let (oauth, _server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();
assert!(oauth.client_id().is_none());
let authorization_data = oauth
.login(redirect_uri.clone(), None, Some(registration_data), None)
.prompt(vec![Prompt::Create])
.build()
.await
.unwrap();
assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
redirect_uri.set_query(Some(&format!("code=42&state={}", authorization_data.state.secret())));
oauth.finish_login(redirect_uri.into()).await?;
Ok(())
}
/// A redirect carrying `error=access_denied` makes `finish_login` fail with
/// [`OAuthAuthorizationCodeError::Cancelled`].
#[async_test]
async fn test_high_level_login_cancellation() -> anyhow::Result<()> {
    let (oauth, server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();

    // Start the flow and make sure the authorization URL is sound.
    let login = oauth.login(redirect_uri.clone(), None, Some(registration_data), None);
    let authorization_data = login.build().await.unwrap();

    assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
    check_authorization_url(&authorization_data, &oauth, &server.uri(), None, None, None, None)
        .await;

    // Simulate the user declining the request.
    let state = authorization_data.state.secret();
    let query = format!("error=access_denied&state={state}");
    redirect_uri.set_query(Some(&query));

    let error = oauth.finish_login(redirect_uri.into()).await.unwrap_err();
    assert_matches!(
        error,
        Error::OAuth(error) => {
            assert_matches!(*error, OAuthError::AuthorizationCode(OAuthAuthorizationCodeError::Cancelled));
        }
    );

    Ok(())
}
/// A redirect whose `state` does not match the pending request makes `finish_login` fail with
/// [`OAuthAuthorizationCodeError::InvalidState`].
#[async_test]
async fn test_high_level_login_invalid_state() -> anyhow::Result<()> {
    let (oauth, server, mut redirect_uri, registration_data) = mock_environment().await.unwrap();

    // Start the flow and make sure the authorization URL is sound.
    let login = oauth.login(redirect_uri.clone(), None, Some(registration_data), None);
    let authorization_data = login.build().await.unwrap();

    assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
    check_authorization_url(&authorization_data, &oauth, &server.uri(), None, None, None, None)
        .await;

    // Answer with a state that was never issued.
    redirect_uri.set_query(Some("code=42&state=imposter_alert"));
    let result = oauth.finish_login(redirect_uri.into()).await;

    let error = result.unwrap_err();
    assert_matches!(
        error,
        Error::OAuth(error) => {
            assert_matches!(*error, OAuthError::AuthorizationCode(OAuthAuthorizationCodeError::InvalidState));
        }
    );

    Ok(())
}
/// The authorization URL reflects the optional builder parameters: prompt, user ID hint and
/// additional scopes.
#[async_test]
async fn test_login_url() -> anyhow::Result<()> {
    let server = MatrixMockServer::new().await;
    let server_uri = server.uri();

    // The server metadata is expected to be requested 4 times in this test.
    let oauth_server = server.oauth();
    oauth_server.mock_server_metadata().ok().expect(4).mount().await;

    let client = server.client_builder().registered_with_oauth().build().await;
    let oauth = client.oauth();

    let device_id = owned_device_id!("D3V1C31D");
    let redirect_uri = Url::parse(REDIRECT_URI_STRING)?;
    let extra_scopes =
        vec![Scope::new("urn:test:scope1".to_owned()), Scope::new("urn:test:scope2".to_owned())];

    // Default parameters.
    let url_data =
        oauth.login(redirect_uri.clone(), Some(device_id.clone()), None, None).build().await?;
    check_authorization_url(&url_data, &oauth, &server_uri, Some(&device_id), None, None, None)
        .await;

    // With the `create` prompt.
    let url_data = oauth
        .login(redirect_uri.clone(), Some(device_id.clone()), None, None)
        .prompt(vec![Prompt::Create])
        .build()
        .await?;
    check_authorization_url(
        &url_data,
        &oauth,
        &server_uri,
        Some(&device_id),
        Some("create"),
        None,
        None,
    )
    .await;

    // With a user ID hint, forwarded as the `login_hint`.
    let url_data = oauth
        .login(redirect_uri.clone(), Some(device_id.clone()), None, None)
        .user_id_hint(user_id!("@joe:example.org"))
        .build()
        .await?;
    check_authorization_url(
        &url_data,
        &oauth,
        &server_uri,
        Some(&device_id),
        None,
        Some("mxid:@joe:example.org"),
        None,
    )
    .await;

    // With additional scopes.
    let url_data = oauth
        .login(redirect_uri, Some(device_id.clone()), None, Some(extra_scopes.clone()))
        .build()
        .await?;
    check_authorization_url(
        &url_data,
        &oauth,
        &server_uri,
        Some(&device_id),
        None,
        None,
        Some(extra_scopes),
    )
    .await;

    Ok(())
}
/// Parsing of the query of the redirect URI sent back by the authorization server.
#[test]
fn test_authorization_response() -> anyhow::Result<()> {
    // Parses `url` into an authorization response.
    let parse = |url: &str| -> anyhow::Result<_> {
        let url_or_query: UrlOrQuery = Url::parse(url)?.into();
        Ok(AuthorizationResponse::parse_url_or_query(&url_or_query))
    };

    // A URL without a query is rejected.
    assert_matches!(parse("https://example.com")?, Err(RedirectUriQueryParseError::MissingQuery));

    // A successful response carries the code and the state.
    assert_matches!(
        parse("https://example.com?code=123&state=456")?,
        Ok(AuthorizationResponse::Success(AuthorizationCode { code, state })) => {
            assert_eq!(code, "123");
            assert_eq!(state.secret(), "456");
        }
    );

    // An error response carries the error and the state.
    assert_matches!(
        parse("https://example.com?error=invalid_scope&state=456")?,
        Ok(AuthorizationResponse::Error(AuthorizationError { error, state })) => {
            assert_eq!(*error.error(), AuthorizationCodeErrorResponseType::InvalidScope);
            assert_eq!(error.error_description(), None);
            assert_eq!(state.secret(), "456");
        }
    );

    Ok(())
}
#[async_test]
async fn test_finish_login() -> anyhow::Result<()> {
let server = MatrixMockServer::new().await;
let oauth_server = server.oauth();
let server_metadata = oauth_server.server_metadata();
let client = server.client_builder().registered_with_oauth().build().await;
let oauth = client.oauth();
let res = oauth.finish_login(UrlOrQuery::Query("code=42&state=none".to_owned())).await;
assert_matches!(
res,
Err(Error::OAuth(error)) => {
assert_matches!(*error, OAuthError::AuthorizationCode(OAuthAuthorizationCodeError::InvalidState));
}
);
assert!(client.session_tokens().is_none());
assert!(client.session_meta().is_none());
let state1 = CsrfToken::new("state1".to_owned());
let redirect_uri = REDIRECT_URI_STRING;
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let auth_validation_data = AuthorizationValidationData {
server_metadata: server_metadata.clone(),
device_id: owned_device_id!("D3V1C31D"),
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
pkce_verifier,
};
{
let data = oauth.data().context("missing data")?;
let prev =
data.authorization_data.lock().await.insert(state1.clone(), auth_validation_data);
assert!(prev.is_none());
}
let res = oauth.finish_login(UrlOrQuery::Query("code=1337&state=none".to_owned())).await;
assert_matches!(
res,
Err(Error::OAuth(error)) => {
assert_matches!(*error, OAuthError::AuthorizationCode(OAuthAuthorizationCodeError::InvalidState));
}
);
assert!(client.session_tokens().is_none());
assert!(oauth.data().unwrap().authorization_data.lock().await.get(&state1).is_some());
oauth_server
.mock_token()
.ok_with_tokens("AT1", "RT1")
.mock_once()
.named("token_1")
.mount()
.await;
server
.mock_who_am_i()
.expect_access_token("AT1")
.ok()
.mock_once()
.named("whoami_1")
.mount()
.await;
oauth.finish_login(UrlOrQuery::Query(format!("code=42&state={}", state1.secret()))).await?;
let session_tokens = client.session_tokens().unwrap();
assert_eq!(session_tokens.access_token, "AT1");
assert_eq!(session_tokens.refresh_token.as_deref(), Some("RT1"));
assert!(client.session_meta().is_some());
assert!(oauth.data().unwrap().authorization_data.lock().await.get(&state1).is_none());
let state2 = CsrfToken::new("state2".to_owned());
let redirect_uri = REDIRECT_URI_STRING;
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let auth_validation_data = AuthorizationValidationData {
server_metadata: server_metadata.clone(),
device_id: owned_device_id!("D3V1C31D"),
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
pkce_verifier,
};
{
let data = oauth.data().context("missing data")?;
let prev =
data.authorization_data.lock().await.insert(state2.clone(), auth_validation_data);
assert!(prev.is_none());
}
oauth_server
.mock_token()
.ok_with_tokens("AT2", "RT2")
.mock_once()
.named("token_2")
.mount()
.await;
server
.mock_who_am_i()
.expect_access_token("AT2")
.ok()
.mock_once()
.named("whoami_2")
.mount()
.await;
oauth.finish_login(UrlOrQuery::Query(format!("code=42&state={}", state2.secret()))).await?;
let session_tokens = client.session_tokens().unwrap();
assert_eq!(session_tokens.access_token, "AT2");
assert_eq!(session_tokens.refresh_token.as_deref(), Some("RT2"));
assert!(client.session_meta().is_some());
assert!(oauth.data().unwrap().authorization_data.lock().await.get(&state2).is_none());
let wrong_device_id = device_id!("WR0NG");
let state3 = CsrfToken::new("state3".to_owned());
let redirect_uri = REDIRECT_URI_STRING;
let (_pkce_code_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let auth_validation_data = AuthorizationValidationData {
server_metadata,
device_id: wrong_device_id.to_owned(),
redirect_uri: RedirectUrl::new(redirect_uri.to_owned())?,
pkce_verifier,
};
{
let data = oauth.data().context("missing data")?;
let prev =
data.authorization_data.lock().await.insert(state3.clone(), auth_validation_data);
assert!(prev.is_none());
}
oauth_server
.mock_token()
.ok_with_tokens("AT3", "RT3")
.mock_once()
.named("token_3")
.mount()
.await;
server
.mock_who_am_i()
.expect_access_token("AT3")
.ok_with_device_id(wrong_device_id)
.mock_once()
.named("whoami_3")
.mount()
.await;
let res =
oauth.finish_login(UrlOrQuery::Query(format!("code=42&state={}", state3.secret()))).await;
assert_matches!(
res,
Err(Error::OAuth(error)) => {
assert_matches!(*error, OAuthError::SessionMismatch);
}
);
assert!(oauth.data().unwrap().authorization_data.lock().await.get(&state3).is_none());
Ok(())
}
/// After restoring an OAuth session, its tokens are available from the client, and its
/// metadata and tokens from the user and full session accessors.
#[async_test]
async fn test_oauth_session() -> anyhow::Result<()> {
    let client = MockClientBuilder::new(None).unlogged().build().await;
    let oauth = client.oauth();

    let expected_tokens = mock_session_tokens_with_refresh();
    let restored = mock_session(expected_tokens.clone());
    oauth.restore_session(restored.clone(), RoomLoadSettings::default()).await?;

    // Client-level tokens.
    assert_eq!(client.session_tokens().unwrap(), expected_tokens);

    // User session.
    let user_session = oauth.user_session().unwrap();
    assert_eq!(user_session.meta, restored.user.meta);
    assert_eq!(user_session.tokens, expected_tokens);

    // Full session, which also includes the client ID.
    let full_session = oauth.full_session().unwrap();
    assert_eq!(full_session.client_id.as_str(), "test_client_id");
    assert_eq!(full_session.user.meta, restored.user.meta);
    assert_eq!(full_session.user.tokens, expected_tokens);

    Ok(())
}
/// Refreshing the access token works for clients that reach the homeserver over plain HTTP,
/// whether they were configured with the homeserver URL or with
/// `insecure_server_name_no_tls`.
#[async_test]
async fn test_insecure_clients() -> anyhow::Result<()> {
    let server = MatrixMockServer::new().await;
    let server_url = server.uri();

    server.mock_well_known().ok().expect(1..).named("well_known").mount().await;
    server.mock_versions().ok().expect(1..).named("versions").mount().await;

    let oauth_server = server.oauth();
    oauth_server.mock_server_metadata().ok().expect(2..).named("server_metadata").mount().await;
    // One refresh per client.
    oauth_server.mock_token().ok().expect(2).named("token").mount().await;

    let prev_tokens = mock_prev_session_tokens_with_refresh();
    let next_tokens = mock_session_tokens_with_refresh();

    let url_client = Client::builder().homeserver_url(&server_url).build().await?;
    let server_name = ServerName::parse(server_url.strip_prefix("http://").unwrap())?;
    let server_name_client =
        Client::builder().insecure_server_name_no_tls(&server_name).build().await?;

    for client in [url_client, server_name_client] {
        let oauth = client.oauth();
        oauth
            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
            .await?;

        let mut changes = client.subscribe_to_session_changes();
        oauth.refresh_access_token().await?;

        assert_eq!(client.session_tokens().unwrap(), next_tokens);

        // Exactly one notification, for the refresh.
        assert_eq!(
            changes.try_recv(),
            Ok(SessionChange::TokensRefreshed),
            "The session changes should be notified of the tokens refresh"
        );
        assert_eq!(
            changes.try_recv(),
            Err(TryRecvError::Empty),
            "There should be no more session changes"
        );
    }

    Ok(())
}
/// Client registration fails when the server metadata comes from `ok_without_registration`,
/// and succeeds (storing the returned client ID) with the default metadata mock.
#[async_test]
async fn test_register_client() {
    let server = MatrixMockServer::new().await;
    let oauth_server = server.oauth();

    let client = server.client_builder().unlogged().build().await;
    let oauth = client.oauth();
    let metadata = mock_client_metadata();

    // Metadata without registration support: registration is not supported.
    oauth_server
        .mock_server_metadata()
        .ok_without_registration()
        .expect(1)
        .named("metadata_without_registration")
        .mount()
        .await;

    let unsupported = oauth.register_client(&metadata).await;
    assert_matches!(
        unsupported,
        Err(OAuthError::ClientRegistration(OAuthClientRegistrationError::NotSupported))
    );

    server.verify_and_reset().await;

    // Metadata with registration support: the client gets registered.
    oauth_server
        .mock_server_metadata()
        .ok()
        .expect(1)
        .named("metadata_with_registration")
        .mount()
        .await;
    oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;

    let registered = oauth.register_client(&metadata).await.unwrap();
    assert_eq!(registered.client_id.as_str(), "test_client_id");

    let stored = oauth.data().unwrap();
    assert_eq!(stored.client_id, registered.client_id);
}
/// The server metadata used to get the account management URL is cached: two lookups result
/// in a single metadata request.
#[async_test]
async fn test_management_url_cache() {
    let server = MatrixMockServer::new().await;
    let oauth_server = server.oauth();
    oauth_server.mock_server_metadata().ok().expect(1).mount().await;

    let client = server.client_builder().logged_in_with_oauth().build().await;
    let oauth = client.oauth();
    let metadata_cache = &client.inner.caches.server_metadata;

    // Nothing is cached yet.
    assert!(!metadata_cache.lock().await.contains("SERVER_METADATA"));

    // The first lookup fetches the metadata and caches it.
    let first_url = oauth
        .account_management_url()
        .await
        .expect("We should be able to fetch the account management url");
    assert!(first_url.is_some());
    assert!(metadata_cache.lock().await.contains("SERVER_METADATA"));

    // The second lookup must not hit the server again (the mock expects a single request).
    let second_url = oauth
        .account_management_url()
        .await
        .expect("We should be able to fetch the account management url");
    assert!(second_url.is_some());
}
/// Fetching the server metadata reports "not supported" until the endpoint is mocked, then
/// succeeds.
#[async_test]
async fn test_server_metadata() {
    let server = MatrixMockServer::new().await;
    let client = server.client_builder().unlogged().build().await;
    let oauth = client.oauth();

    // No metadata endpoint yet.
    let not_supported = oauth.server_metadata().await.unwrap_err();
    assert!(not_supported.is_not_supported());

    // With the endpoint mocked.
    let oauth_server = server.oauth();
    oauth_server.mock_server_metadata().ok().expect(1).named("auth_metadata").mount().await;
    oauth.server_metadata().await.unwrap();
}
/// `OAuth::use_registration_data`:
/// - fails with `NotRegistered` when there is neither registration data nor a client ID,
/// - uses a static registration matching the server's issuer,
/// - keeps an already configured client ID instead of replacing it,
/// - otherwise registers the client with the provided client metadata.
#[async_test]
async fn test_client_registration_data() {
    let server = MatrixMockServer::new().await;
    let oauth_server = server.oauth();
    let server_metadata = oauth_server.server_metadata();

    let client = server.client_builder().unlogged().build().await;
    let oauth = client.oauth();

    // Without registration data, the client cannot be registered.
    let res = oauth.use_registration_data(&server_metadata, None).await;
    assert_matches!(res, Err(OAuthError::NotRegistered));
    assert_eq!(oauth.client_id(), None);

    // A static registration for the server's issuer is used.
    let registration_data = ClientRegistrationData {
        metadata: mock_client_metadata(),
        static_registrations: Some([(server_metadata.issuer.clone(), mock_client_id())].into()),
    };
    oauth.use_registration_data(&server_metadata, Some(&registration_data)).await.unwrap();
    assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));

    // Once a client ID is set, another static registration does not replace it.
    let registration_data = ClientRegistrationData {
        metadata: mock_client_metadata(),
        static_registrations: Some(
            [(server_metadata.issuer.clone(), ClientId::new("other_client_id".to_owned()))].into(),
        ),
    };
    oauth.use_registration_data(&server_metadata, Some(&registration_data)).await.unwrap();
    assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));

    // A fresh client without a static registration registers using the client metadata.
    let client_metadata = mock_client_metadata();
    let client = server.client_builder().unlogged().build().await;
    let oauth = client.oauth();

    oauth_server
        .mock_registration()
        .ok()
        .mock_once()
        .named("registration_with_metadata")
        .mount()
        .await;

    oauth.use_registration_data(&server_metadata, Some(&client_metadata.into())).await.unwrap();
    assert_eq!(oauth.client_id().map(|id| id.as_str()), Some("test_client_id"));
}