use matrix_sdk_base::BoxFuture;
use oauth2::{
AsyncHttpClient, ErrorResponse, HttpClientError, HttpRequest, HttpResponse, RequestTokenError,
};
#[derive(Debug, Clone)]
pub(super) struct OAuthHttpClient {
pub(super) inner: reqwest::Client,
#[cfg(test)]
pub(super) insecure_rewrite_https_to_http: bool,
}
impl<'c> AsyncHttpClient<'c> for OAuthHttpClient {
type Error = HttpClientError<reqwest::Error>;
type Future = BoxFuture<'c, Result<HttpResponse, Self::Error>>;
fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move {
#[cfg(test)]
let request = if self.insecure_rewrite_https_to_http
&& request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS)
{
let mut request = request;
let mut uri_parts = request.uri().clone().into_parts();
uri_parts.scheme = Some(http::uri::Scheme::HTTP);
*request.uri_mut() = http::uri::Uri::from_parts(uri_parts)
.expect("reconstructing URI from parts should work");
request
} else {
request
};
let response = self.inner.call(request).await?;
Ok(response)
})
}
}
pub(super) fn check_http_response_status_code<T: ErrorResponse + 'static>(
http_response: &HttpResponse,
) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
if http_response.status().as_u16() < 400 {
return Ok(());
}
let reason = http_response.body().as_slice();
let error = if reason.is_empty() {
RequestTokenError::Other("server returned an empty error response".to_owned())
} else {
match serde_json::from_slice(reason) {
Ok(error) => RequestTokenError::ServerResponse(error),
Err(error) => RequestTokenError::Other(error.to_string()),
}
};
Err(error)
}
pub(super) fn check_http_response_json_content_type<T: ErrorResponse + 'static>(
http_response: &HttpResponse,
) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
let Some(content_type) = http_response.headers().get(http::header::CONTENT_TYPE) else {
return Ok(());
};
if content_type
.to_str()
.is_ok_and(|ct| ct.to_lowercase().starts_with(mime::APPLICATION_JSON.essence_str()))
{
Ok(())
} else {
Err(RequestTokenError::Other(format!(
"unexpected response Content-Type: {content_type:?}, should be `{}`",
mime::APPLICATION_JSON
)))
}
}
#[cfg(test)]
mod tests {
use assert_matches2::assert_matches;
use oauth2::{RequestTokenError, basic::BasicErrorResponse};
use super::{check_http_response_json_content_type, check_http_response_status_code};
#[test]
fn test_check_http_response_status_code() {
let response = http::Response::builder().status(200).body(Vec::<u8>::new()).unwrap();
assert_matches!(check_http_response_status_code::<BasicErrorResponse>(&response), Ok(()));
let response = http::Response::builder().status(404).body(Vec::<u8>::new()).unwrap();
assert_matches!(
check_http_response_status_code::<BasicErrorResponse>(&response),
Err(RequestTokenError::Other(_))
);
let response =
http::Response::builder().status(404).body(b"invalid error format".to_vec()).unwrap();
assert_matches!(
check_http_response_status_code::<BasicErrorResponse>(&response),
Err(RequestTokenError::Other(_))
);
let response = http::Response::builder()
.status(404)
.body(br#"{"error": "invalid_request"}"#.to_vec())
.unwrap();
assert_matches!(
check_http_response_status_code::<BasicErrorResponse>(&response),
Err(RequestTokenError::ServerResponse(_))
);
}
#[test]
fn test_check_http_response_json_content_type() {
let response = http::Response::builder()
.status(200)
.header(http::header::CONTENT_TYPE, "application/json")
.body(b"{}".to_vec())
.unwrap();
assert_matches!(
check_http_response_json_content_type::<BasicErrorResponse>(&response),
Ok(())
);
let response = http::Response::builder()
.status(200)
.header(http::header::CONTENT_TYPE, "application/json; charset=utf-8")
.body(b"{}".to_vec())
.unwrap();
assert_matches!(
check_http_response_json_content_type::<BasicErrorResponse>(&response),
Ok(())
);
let response = http::Response::builder().status(200).body(b"{}".to_vec()).unwrap();
assert_matches!(
check_http_response_json_content_type::<BasicErrorResponse>(&response),
Ok(())
);
let response = http::Response::builder()
.status(200)
.header(http::header::CONTENT_TYPE, "text/html")
.body(b"<html><body><h1>HTML!</h1></body></html>".to_vec())
.unwrap();
assert_matches!(
check_http_response_json_content_type::<BasicErrorResponse>(&response),
Err(RequestTokenError::Other(_))
);
}
}