1use std::{
16 any::type_name,
17 fmt::Debug,
18 num::NonZeroUsize,
19 sync::{
20 atomic::{AtomicU64, Ordering},
21 Arc,
22 },
23 time::Duration,
24};
25
26use bytes::{Bytes, BytesMut};
27use bytesize::ByteSize;
28use eyeball::SharedObservable;
29use http::Method;
30use ruma::api::{
31 error::{FromHttpResponseError, IntoHttpError},
32 AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
33};
34use tokio::sync::{Semaphore, SemaphorePermit};
35use tracing::{debug, field::debug, instrument, trace};
36
37use crate::{config::RequestConfig, error::HttpError};
38
39#[cfg(not(target_arch = "wasm32"))]
40mod native;
41#[cfg(target_arch = "wasm32")]
42mod wasm;
43
44#[cfg(not(target_arch = "wasm32"))]
45pub(crate) use native::HttpSettings;
46
/// Timeout applied to requests by default: thirty seconds.
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
48
49#[derive(Clone, Debug)]
50struct MaybeSemaphore(Arc<Option<Semaphore>>);
51
52#[allow(dead_code)] struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
54
55impl MaybeSemaphore {
56 fn new(max: Option<NonZeroUsize>) -> Self {
57 let inner = max.map(|i| Semaphore::new(i.into()));
58 MaybeSemaphore(Arc::new(inner))
59 }
60
61 async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
62 match self.0.as_ref() {
63 Some(inner) => {
64 MaybeSemaphorePermit(inner.acquire().await.ok())
67 }
68 None => MaybeSemaphorePermit(None),
69 }
70 }
71}
72
/// HTTP client used to send Matrix API requests that are not server-signed.
///
/// Clones share the concurrency limiter and the request-id counter, because both live
/// behind `Arc`s.
#[derive(Clone, Debug)]
pub(crate) struct HttpClient {
    /// The wrapped `reqwest` client.
    pub(crate) inner: reqwest::Client,
    /// Configuration used for requests that do not provide their own.
    pub(crate) request_config: RequestConfig,
    /// Bounds the number of requests in flight at once, if a limit is configured.
    concurrent_request_semaphore: MaybeSemaphore,
    /// Counter used to build the `REQ-<n>` identifiers recorded on request spans.
    next_request_id: Arc<AtomicU64>,
}
80
81impl HttpClient {
82 pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
83 HttpClient {
84 inner,
85 request_config,
86 concurrent_request_semaphore: MaybeSemaphore::new(
87 request_config.max_concurrent_requests,
88 ),
89 next_request_id: AtomicU64::new(0).into(),
90 }
91 }
92
93 fn get_request_id(&self) -> String {
94 let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
95 format!("REQ-{request_id}")
96 }
97
98 fn serialize_request<R>(
99 &self,
100 request: R,
101 config: RequestConfig,
102 homeserver: String,
103 access_token: Option<&str>,
104 server_versions: &[MatrixVersion],
105 ) -> Result<http::Request<Bytes>, IntoHttpError>
106 where
107 R: OutgoingRequest + Debug,
108 {
109 trace!(request_type = type_name::<R>(), "Serializing request");
110
111 let server_versions = if config.force_matrix_version.is_some() {
112 config.force_matrix_version.as_slice()
113 } else {
114 server_versions
115 };
116
117 let send_access_token = match access_token {
118 Some(access_token) => {
119 if config.force_auth {
120 SendAccessToken::Always(access_token)
121 } else {
122 SendAccessToken::IfRequired(access_token)
123 }
124 }
125 None => SendAccessToken::None,
126 };
127
128 let request = request
129 .try_into_http_request::<BytesMut>(&homeserver, send_access_token, server_versions)?
130 .map(|body| body.freeze());
131
132 Ok(request)
133 }
134
135 #[allow(clippy::too_many_arguments)]
136 #[instrument(
137 skip(self, request, config, homeserver, access_token, server_versions, send_progress),
138 fields(uri, method, request_size, request_id, status, response_size, sentry_event_id)
139 )]
140 pub async fn send<R>(
141 &self,
142 request: R,
143 config: Option<RequestConfig>,
144 homeserver: String,
145 access_token: Option<&str>,
146 server_versions: &[MatrixVersion],
147 send_progress: SharedObservable<TransmissionProgress>,
148 ) -> Result<R::IncomingResponse, HttpError>
149 where
150 R: OutgoingRequest + Debug,
151 HttpError: From<FromHttpResponseError<R::EndpointError>>,
152 {
153 let config = match config {
154 Some(config) => config,
155 None => self.request_config,
156 };
157
158 let request = {
161 let request_id = self.get_request_id();
162 let span = tracing::Span::current();
163
164 span.record("config", debug(config)).record("request_id", request_id);
167
168 let auth_scheme = R::METADATA.authentication;
169 match auth_scheme {
170 AuthScheme::AccessToken
171 | AuthScheme::AccessTokenOptional
172 | AuthScheme::AppserviceToken
173 | AuthScheme::None => {}
174 AuthScheme::ServerSignatures => {
175 return Err(HttpError::NotClientRequest);
176 }
177 }
178
179 let request = self
180 .serialize_request(request, config, homeserver, access_token, server_versions)
181 .map_err(HttpError::IntoHttp)?;
182
183 let method = request.method();
184
185 let mut uri_parts = request.uri().clone().into_parts();
186 if let Some(path_and_query) = &mut uri_parts.path_and_query {
187 *path_and_query =
188 path_and_query.path().try_into().expect("path is valid PathAndQuery");
189 }
190 let uri = http::Uri::from_parts(uri_parts).expect("created from valid URI");
191
192 span.record("method", debug(method)).record("uri", uri.to_string());
193
194 if [Method::POST, Method::PUT, Method::PATCH].contains(method) {
197 let request_size = request.body().len().try_into().unwrap_or(u64::MAX);
198 span.record(
199 "request_size",
200 ByteSize(request_size).display().si_short().to_string(),
201 );
202 }
203
204 request
205 };
206
207 let _handle = self.concurrent_request_semaphore.acquire().await;
209
210 match Box::pin(self.send_request::<R>(request, config, send_progress)).await {
213 Ok(response) => {
214 debug!("Got response");
215 Ok(response)
216 }
217 Err(e) => {
218 debug!("Error while sending request: {e:?}");
219 Err(e)
220 }
221 }
222 }
223}
224
/// Progress of a request transmission, published through the `send_progress`
/// observable that `HttpClient::send` accepts.
///
/// The default value reports nothing transmitted out of a total of zero.
#[derive(Debug, Default, Clone, Copy)]
pub struct TransmissionProgress {
    /// Amount transmitted so far (presumably bytes; confirm against the transport code).
    pub current: usize,
    /// Total amount to transmit (presumably bytes; confirm against the transport code).
    pub total: usize,
}
233
234async fn response_to_http_response(
235 mut response: reqwest::Response,
236) -> Result<http::Response<Bytes>, reqwest::Error> {
237 let status = response.status();
238
239 let mut http_builder = http::Response::builder().status(status);
240 let headers = http_builder.headers_mut().expect("Can't get the response builder headers");
241
242 for (k, v) in response.headers_mut().drain() {
243 if let Some(key) = k {
244 headers.insert(key, v);
245 }
246 }
247
248 let body = response.bytes().await?;
249
250 Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
251}
252
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use std::{
        num::NonZeroUsize,
        sync::{
            atomic::{AtomicU8, Ordering},
            Arc,
        },
        time::Duration,
    };

    use matrix_sdk_test::{async_test, test_json};
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, Request, ResponseTemplate,
    };

    use crate::{
        http_client::RequestConfig,
        test_utils::{set_client_session, test_client_builder_with_server},
    };

    /// Mounts two endpoints on `server`:
    /// - a `/versions` endpoint;
    /// - a `whoami` endpoint whose responses are delayed by 60 seconds, so every `whoami`
    ///   request that reaches the server stays in flight for the rest of a test.
    ///
    /// Returns a counter that the server increments for each `whoami` request it receives.
    async fn mount_slow_whoami(server: &MockServer) -> Arc<AtomicU8> {
        let received = Arc::new(AtomicU8::new(0));
        let received_by_responder = Arc::clone(&received);

        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
            .mount(server)
            .await;

        Mock::given(method("GET"))
            .and(path("_matrix/client/r0/account/whoami"))
            .respond_with(move |_req: &Request| {
                received_by_responder.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
            })
            .mount(server)
            .await;

        received
    }

    #[async_test]
    async fn test_ensure_concurrent_request_limit_is_observed() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        let received = mount_slow_whoami(&server).await;

        // Fire twice as many requests as the limit allows. None can finish before the
        // assertion below because of the 60 s response delay.
        let in_flight = tokio::spawn(async move {
            futures_util::future::join_all((0..10).map(|_| client.whoami())).await
        });

        // Give the permitted requests time to reach the mock server.
        tokio::time::sleep(Duration::from_millis(300)).await;

        assert_eq!(
            received.load(Ordering::SeqCst),
            5,
            "More requests passed than the limit we configured"
        );
        in_flight.abort();
    }

    #[async_test]
    async fn test_ensure_no_max_concurrent_request_does_not_limit() {
        let (client_builder, server) = test_client_builder_with_server().await;
        let client = client_builder
            .request_config(RequestConfig::default().max_concurrent_requests(None))
            .build()
            .await
            .unwrap();

        set_client_session(&client).await;

        let received = mount_slow_whoami(&server).await;

        // 254 still fits in the `AtomicU8` counter.
        let in_flight = tokio::spawn(async move {
            futures_util::future::join_all((0..254).map(|_| client.whoami())).await
        });

        // Without a limit, every request should reach the server within this window.
        tokio::time::sleep(Duration::from_secs(1)).await;

        assert_eq!(received.load(Ordering::SeqCst), 254, "Not all requests passed through");
        in_flight.abort();
    }
}