Skip to main content

matrix_sdk/authentication/oauth/
http_client.rs

1// Copyright 2025 Kévin Commaille
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! HTTP client and helpers for making OAuth 2.0 requests.
16
17use matrix_sdk_base::BoxFuture;
18use oauth2::{
19    AsyncHttpClient, ErrorResponse, HttpClientError, HttpRequest, HttpResponse, RequestTokenError,
20};
21use oauth2_reqwest::ReqwestClient;
22
23/// An HTTP client for making OAuth 2.0 requests.
24#[derive(Debug, Clone)]
25pub(super) struct OAuthHttpClient {
26    pub(super) inner: ReqwestClient,
27    /// Rewrite HTTPS requests to use HTTP instead.
28    ///
29    /// This is a workaround to bypass some checks that require an HTTPS URL,
30    /// but we can only mock HTTP URLs.
31    #[cfg(test)]
32    pub(super) insecure_rewrite_https_to_http: bool,
33}
34
35impl<'c> AsyncHttpClient<'c> for OAuthHttpClient {
36    type Error = HttpClientError<reqwest::Error>;
37
38    type Future = BoxFuture<'c, Result<HttpResponse, Self::Error>>;
39
40    fn call(&'c self, request: HttpRequest) -> Self::Future {
41        Box::pin(async move {
42            #[cfg(test)]
43            let request = if self.insecure_rewrite_https_to_http
44                && request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS)
45            {
46                let mut request = request;
47
48                let mut uri_parts = request.uri().clone().into_parts();
49                uri_parts.scheme = Some(http::uri::Scheme::HTTP);
50                *request.uri_mut() = http::uri::Uri::from_parts(uri_parts)
51                    .expect("reconstructing URI from parts should work");
52
53                request
54            } else {
55                request
56            };
57
58            let response = self.inner.call(request).await?;
59
60            Ok(response)
61        })
62    }
63}
64
65/// Check the status code of the given HTTP response to identify errors.
66pub(super) fn check_http_response_status_code<T: ErrorResponse + 'static>(
67    http_response: &HttpResponse,
68) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
69    if http_response.status().as_u16() < 400 {
70        return Ok(());
71    }
72
73    let reason = http_response.body().as_slice();
74    let error = if reason.is_empty() {
75        RequestTokenError::Other("server returned an empty error response".to_owned())
76    } else {
77        match serde_json::from_slice(reason) {
78            Ok(error) => RequestTokenError::ServerResponse(error),
79            Err(error) => RequestTokenError::Other(error.to_string()),
80        }
81    };
82
83    Err(error)
84}
85
86/// Check that the server returned a response with a JSON `Content-Type`.
87pub(super) fn check_http_response_json_content_type<T: ErrorResponse + 'static>(
88    http_response: &HttpResponse,
89) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
90    let Some(content_type) = http_response.headers().get(http::header::CONTENT_TYPE) else {
91        return Ok(());
92    };
93
94    if content_type
95        .to_str()
96        // Check only the beginning of the content type, because there might be extra
97        // parameters, like a charset.
98        .is_ok_and(|ct| ct.to_lowercase().starts_with(mime::APPLICATION_JSON.essence_str()))
99    {
100        Ok(())
101    } else {
102        Err(RequestTokenError::Other(format!(
103            "unexpected response Content-Type: {content_type:?}, should be `{}`",
104            mime::APPLICATION_JSON
105        )))
106    }
107}
108
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use oauth2::{RequestTokenError, basic::BasicErrorResponse};

    use super::{check_http_response_json_content_type, check_http_response_status_code};

    /// Build a response with the given status code and body, without any header.
    fn response_with_status(status: u16, body: &[u8]) -> http::Response<Vec<u8>> {
        http::Response::builder().status(status).body(body.to_vec()).unwrap()
    }

    /// Build a `200 OK` response with the given body and, if provided, `Content-Type`.
    fn ok_response_with_content_type(
        content_type: Option<&str>,
        body: &[u8],
    ) -> http::Response<Vec<u8>> {
        let mut builder = http::Response::builder().status(200);
        if let Some(content_type) = content_type {
            builder = builder.header(http::header::CONTENT_TYPE, content_type);
        }
        builder.body(body.to_vec()).unwrap()
    }

    #[test]
    fn test_check_http_response_status_code() {
        // Successful status.
        let response = response_with_status(200, b"");
        assert_matches!(check_http_response_status_code::<BasicErrorResponse>(&response), Ok(()));

        // Error status with an empty body.
        let response = response_with_status(404, b"");
        assert_matches!(
            check_http_response_status_code::<BasicErrorResponse>(&response),
            Err(RequestTokenError::Other(_))
        );

        // Error status with a body that is not a valid error response.
        let response = response_with_status(404, b"invalid error format");
        assert_matches!(
            check_http_response_status_code::<BasicErrorResponse>(&response),
            Err(RequestTokenError::Other(_))
        );

        // Error status with a valid error response body.
        let response = response_with_status(404, br#"{"error": "invalid_request"}"#);
        assert_matches!(
            check_http_response_status_code::<BasicErrorResponse>(&response),
            Err(RequestTokenError::ServerResponse(_))
        );
    }

    #[test]
    fn test_check_http_response_json_content_type() {
        // Plain JSON content type.
        let response = ok_response_with_content_type(Some("application/json"), b"{}");
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&response),
            Ok(())
        );

        // JSON content type with a charset parameter.
        let response =
            ok_response_with_content_type(Some("application/json; charset=utf-8"), b"{}");
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&response),
            Ok(())
        );

        // No content type at all.
        let response = ok_response_with_content_type(None, b"{}");
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&response),
            Ok(())
        );

        // Non-JSON content type.
        let response = ok_response_with_content_type(
            Some("text/html"),
            b"<html><body><h1>HTML!</h1></body></html>",
        );
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&response),
            Err(RequestTokenError::Other(_))
        );
    }
}