matrix_sdk/authentication/oauth/
http_client.rs1use matrix_sdk_base::BoxFuture;
18use oauth2::{
19 AsyncHttpClient, ErrorResponse, HttpClientError, HttpRequest, HttpResponse, RequestTokenError,
20};
21use oauth2_reqwest::ReqwestClient;
22
/// HTTP client handed to the `oauth2` crate for making OAuth 2.0 requests.
///
/// Requests are forwarded to the wrapped [`ReqwestClient`].
#[derive(Debug, Clone)]
pub(super) struct OAuthHttpClient {
    /// The underlying client that actually sends the requests.
    pub(super) inner: ReqwestClient,
    /// Test-only switch: when `true`, a request whose URI scheme is `https`
    /// has its scheme replaced by `http` before it is sent.
    ///
    /// NOTE(review): presumably this exists so tests can target a plain-HTTP
    /// mock server while the code under test still sees HTTPS URLs — confirm.
    #[cfg(test)]
    pub(super) insecure_rewrite_https_to_http: bool,
}
34
35impl<'c> AsyncHttpClient<'c> for OAuthHttpClient {
36 type Error = HttpClientError<reqwest::Error>;
37
38 type Future = BoxFuture<'c, Result<HttpResponse, Self::Error>>;
39
40 fn call(&'c self, request: HttpRequest) -> Self::Future {
41 Box::pin(async move {
42 #[cfg(test)]
43 let request = if self.insecure_rewrite_https_to_http
44 && request.uri().scheme().is_some_and(|scheme| *scheme == http::uri::Scheme::HTTPS)
45 {
46 let mut request = request;
47
48 let mut uri_parts = request.uri().clone().into_parts();
49 uri_parts.scheme = Some(http::uri::Scheme::HTTP);
50 *request.uri_mut() = http::uri::Uri::from_parts(uri_parts)
51 .expect("reconstructing URI from parts should work");
52
53 request
54 } else {
55 request
56 };
57
58 let response = self.inner.call(request).await?;
59
60 Ok(response)
61 })
62 }
63}
64
65pub(super) fn check_http_response_status_code<T: ErrorResponse + 'static>(
67 http_response: &HttpResponse,
68) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
69 if http_response.status().as_u16() < 400 {
70 return Ok(());
71 }
72
73 let reason = http_response.body().as_slice();
74 let error = if reason.is_empty() {
75 RequestTokenError::Other("server returned an empty error response".to_owned())
76 } else {
77 match serde_json::from_slice(reason) {
78 Ok(error) => RequestTokenError::ServerResponse(error),
79 Err(error) => RequestTokenError::Other(error.to_string()),
80 }
81 };
82
83 Err(error)
84}
85
86pub(super) fn check_http_response_json_content_type<T: ErrorResponse + 'static>(
88 http_response: &HttpResponse,
89) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, T>> {
90 let Some(content_type) = http_response.headers().get(http::header::CONTENT_TYPE) else {
91 return Ok(());
92 };
93
94 if content_type
95 .to_str()
96 .is_ok_and(|ct| ct.to_lowercase().starts_with(mime::APPLICATION_JSON.essence_str()))
99 {
100 Ok(())
101 } else {
102 Err(RequestTokenError::Other(format!(
103 "unexpected response Content-Type: {content_type:?}, should be `{}`",
104 mime::APPLICATION_JSON
105 )))
106 }
107}
108
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use oauth2::{RequestTokenError, basic::BasicErrorResponse};

    use super::{check_http_response_json_content_type, check_http_response_status_code};

    /// Builds an in-memory HTTP response with the given status code, optional
    /// `Content-Type` header value and body.
    fn make_response(
        status: u16,
        content_type: Option<&str>,
        body: &[u8],
    ) -> http::Response<Vec<u8>> {
        let mut builder = http::Response::builder().status(status);
        if let Some(value) = content_type {
            builder = builder.header(http::header::CONTENT_TYPE, value);
        }
        builder.body(body.to_vec()).unwrap()
    }

    #[test]
    fn test_check_http_response_status_code() {
        // A successful status code is accepted.
        let success = make_response(200, None, b"");
        assert_matches!(check_http_response_status_code::<BasicErrorResponse>(&success), Ok(()));

        // An error status code with an empty body.
        let empty_error = make_response(404, None, b"");
        assert_matches!(
            check_http_response_status_code::<BasicErrorResponse>(&empty_error),
            Err(RequestTokenError::Other(_))
        );

        // An error status code with a body that is not a valid error response.
        let malformed_error = make_response(404, None, b"invalid error format");
        assert_matches!(
            check_http_response_status_code::<BasicErrorResponse>(&malformed_error),
            Err(RequestTokenError::Other(_))
        );

        // An error status code with a well-formed error response.
        let server_error = make_response(404, None, br#"{"error": "invalid_request"}"#);
        assert_matches!(
            check_http_response_status_code::<BasicErrorResponse>(&server_error),
            Err(RequestTokenError::ServerResponse(_))
        );
    }

    #[test]
    fn test_check_http_response_json_content_type() {
        // The bare JSON media type is accepted.
        let plain_json = make_response(200, Some("application/json"), b"{}");
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&plain_json),
            Ok(())
        );

        // The JSON media type followed by a parameter is accepted.
        let json_with_charset =
            make_response(200, Some("application/json; charset=utf-8"), b"{}");
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&json_with_charset),
            Ok(())
        );

        // A response without a `Content-Type` header is accepted.
        let no_content_type = make_response(200, None, b"{}");
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&no_content_type),
            Ok(())
        );

        // Any other media type is rejected.
        let html = make_response(
            200,
            Some("text/html"),
            b"<html><body><h1>HTML!</h1></body></html>",
        );
        assert_matches!(
            check_http_response_json_content_type::<BasicErrorResponse>(&html),
            Err(RequestTokenError::Other(_))
        );
    }
}