1#[cfg(target_os = "android")]
16use std::sync::Arc;
17use std::{
18 fmt::Debug,
19 mem,
20 sync::atomic::{AtomicU64, Ordering},
21 time::Duration,
22};
23
24use backon::{ExponentialBuilder, Retryable};
25use bytes::Bytes;
26use bytesize::ByteSize;
27use eyeball::SharedObservable;
28use http::header::CONTENT_LENGTH;
29#[cfg(not(target_family = "wasm"))]
30use reqwest::Certificate;
31#[cfg(target_os = "android")]
32use reqwest::ClientBuilder;
33use reqwest::tls;
34use ruma::api::{IncomingResponse, OutgoingRequest, error::FromHttpResponseError};
35#[cfg(target_os = "android")]
36use rustls::{RootCertStore, client::WebPkiServerVerifier};
37#[cfg(target_os = "android")]
38use rustls_pki_types::CertificateDer;
39use tracing::{debug, info, warn};
40
41use super::{DEFAULT_REQUEST_TIMEOUT, HttpClient, TransmissionProgress, response_to_http_response};
42use crate::{
43 config::RequestConfig,
44 error::{HttpError, RetryKind},
45};
46
47impl HttpClient {
48 pub(super) async fn send_request<R>(
49 &self,
50 request: http::Request<Bytes>,
51 config: RequestConfig,
52 send_progress: SharedObservable<TransmissionProgress>,
53 ) -> Result<R::IncomingResponse, HttpError>
54 where
55 R: OutgoingRequest + Debug,
56 HttpError: From<FromHttpResponseError<R::EndpointError>>,
57 {
58 let backoff = ExponentialBuilder::new()
61 .with_min_delay(Duration::from_millis(500))
62 .with_max_delay(Duration::from_secs(60))
63 .with_total_delay(Some(Duration::from_secs(15 * 60)))
64 .without_max_times();
65
66 let backoff = if let Some(max_delay) = config.max_retry_time {
68 backoff.with_max_delay(max_delay)
69 } else {
70 backoff
71 };
72
73 let backoff = if let Some(max_times) = config.retry_limit {
74 backoff.with_max_times(max_times.saturating_sub(1))
77 } else {
78 backoff
79 };
80
81 let retry_count = AtomicU64::new(1);
82
83 let send_request = || {
84 let send_progress = send_progress.clone();
85
86 async {
87 let num_attempt = retry_count.fetch_add(1, Ordering::SeqCst);
88 debug!(num_attempt, "Sending request");
89 let before = ruma::time::Instant::now();
90
91 let response =
92 send_request(&self.inner, &request, config.timeout, send_progress).await?;
93
94 let request_duration = ruma::time::Instant::now().saturating_duration_since(before);
95
96 let status_code = response.status();
97 let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
98 tracing::Span::current()
99 .record("status", status_code.as_u16())
100 .record("response_size", response_size.display().si_short().to_string())
101 .record("request_duration", tracing::field::debug(request_duration));
102
103 for (header_name, header_value) in response.headers() {
106 let header_name = header_name.as_str().to_lowercase();
107
108 if header_name == "x-sentry-event-id" {
111 tracing::Span::current()
112 .record("sentry_event_id", header_value.to_str().unwrap_or("<???>"));
113 }
114 }
115
116 R::IncomingResponse::try_from_http_response(response).map_err(HttpError::from)
117 }
118 };
119
120 let has_retry_limit = config.retry_limit.is_some();
121
122 send_request
123 .retry(backoff)
124 .adjust(|err, backon_suggested_timeout| {
125 match err.retry_kind() {
126 RetryKind::Transient { retry_after } => {
127 if backon_suggested_timeout.is_some() {
135 retry_after.or(backon_suggested_timeout)
136 } else {
137 None
138 }
139 }
140 RetryKind::Permanent => None,
141 RetryKind::NetworkFailure => {
142 if has_retry_limit { backon_suggested_timeout } else { None }
146 }
147 }
148 })
149 .await
150 }
151}
152
/// Configuration used by [`HttpSettings::make_client`] to build the native `reqwest` client.
#[cfg(not(target_family = "wasm"))]
#[derive(Clone, Debug)]
pub(crate) struct HttpSettings {
    /// When `true`, the client accepts invalid server certificates.
    pub(crate) disable_ssl_verification: bool,
    /// Proxy URL applied to all requests, if any.
    pub(crate) proxy: Option<String>,
    /// Custom `User-Agent`; `make_client` falls back to `"matrix-rust-sdk"` when `None`.
    pub(crate) user_agent: Option<String>,
    /// Overall request timeout passed to the `reqwest` builder, if any.
    pub(crate) timeout: Option<Duration>,
    /// Read timeout passed to the `reqwest` builder, if any.
    pub(crate) read_timeout: Option<Duration>,
    /// Extra trusted root certificates, merged with (or replacing) the built-in ones.
    pub(crate) additional_root_certificates: Vec<Certificate>,
    /// Extra trusted root certificates as raw bytes (parsed as DER by the Android verifier).
    #[cfg(target_os = "android")]
    pub(crate) additional_raw_root_certificates: Vec<Vec<u8>>,
    /// When `true`, only the additional root certificates are trusted.
    pub(crate) disable_built_in_root_certificates: bool,
}
166
167#[cfg(not(target_family = "wasm"))]
168impl Default for HttpSettings {
169 fn default() -> Self {
170 Self {
171 disable_ssl_verification: false,
172 proxy: None,
173 user_agent: None,
174 timeout: Some(DEFAULT_REQUEST_TIMEOUT),
175 read_timeout: None,
176 additional_root_certificates: Default::default(),
177 #[cfg(target_os = "android")]
178 additional_raw_root_certificates: Default::default(),
179 disable_built_in_root_certificates: false,
180 }
181 }
182}
183
184#[cfg(not(target_family = "wasm"))]
185impl HttpSettings {
186 pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
188 let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
189 let mut http_client = reqwest::Client::builder()
190 .user_agent(user_agent)
191 .min_tls_version(tls::Version::TLS_1_2);
194
195 if let Some(timeout) = self.timeout {
196 http_client = http_client.timeout(timeout);
197 }
198
199 if let Some(read_timeout) = self.read_timeout {
200 http_client = http_client.read_timeout(read_timeout);
201 }
202
203 #[cfg(target_os = "android")]
208 {
209 http_client = self.android_setup_webkpi_verifier(http_client)?;
210 }
211
212 if self.disable_ssl_verification {
213 warn!("SSL verification disabled in the HTTP client!");
214 http_client = http_client.danger_accept_invalid_certs(true);
215 }
216
217 http_client = if self.disable_built_in_root_certificates {
218 info!("Built-in root certificates disabled in the HTTP client.");
219 http_client.tls_certs_only(self.additional_root_certificates.clone())
220 } else {
221 http_client.tls_certs_merge(self.additional_root_certificates.clone())
222 };
223
224 if let Some(p) = &self.proxy {
225 info!(proxy_url = p, "Setting the proxy for the HTTP client");
226 http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
227 }
228
229 Ok(http_client.build()?)
230 }
231
232 #[cfg(target_os = "android")]
233 fn android_setup_webkpi_verifier(
234 &self,
235 client_builder: ClientBuilder,
236 ) -> Result<ClientBuilder, HttpError> {
237 if !self.disable_ssl_verification {
238 let mut root_store = RootCertStore::empty();
239
240 if self.disable_built_in_root_certificates {
241 info!("Built-in root certificates disabled in the HTTP client.");
242 } else {
243 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
245
246 let native_certs = rustls_native_certs::load_native_certs().certs;
248 root_store.add_parsable_certificates(native_certs);
249 }
250
251 if !self.additional_raw_root_certificates.is_empty() {
252 let mut additional_certs = Vec::new();
253
254 warn!(
255 "Adding {} extra user certificates",
256 self.additional_raw_root_certificates.len()
257 );
258
259 for certificate in self.additional_raw_root_certificates.iter() {
260 additional_certs.push(CertificateDer::from_slice(certificate));
261 }
262
263 root_store.add_parsable_certificates(additional_certs);
264 }
265
266 let verifier = WebPkiServerVerifier::builder(Arc::new(root_store))
267 .build()
268 .map_err(HttpError::VerifierBuilder)?;
269
270 let config = rustls::ClientConfig::builder()
271 .with_webpki_verifier(verifier)
272 .with_no_client_auth();
273 Ok(client_builder.tls_backend_preconfigured(config))
274 } else {
275 Ok(client_builder)
276 }
277 }
278}
279
280pub(super) async fn send_request(
281 client: &reqwest::Client,
282 request: &http::Request<Bytes>,
283 timeout: Option<Duration>,
284 send_progress: SharedObservable<TransmissionProgress>,
285) -> Result<http::Response<Bytes>, HttpError> {
286 use std::convert::Infallible;
287
288 use futures_util::stream;
289
290 let request = request.clone();
291 let request = {
292 let mut request = if send_progress.subscriber_count() != 0 {
293 let content_length = request.body().len();
294 send_progress.update(|p| p.total += content_length);
295
296 tokio::task::yield_now().await;
300
301 let mut req = reqwest::Request::try_from(request.map(|body| {
302 let chunks = stream::iter(BytesChunks::new(body, 8192).map(
303 move |chunk| -> Result<_, Infallible> {
304 send_progress.update(|p| p.current += chunk.len());
305 Ok(chunk)
306 },
307 ));
308 reqwest::Body::wrap_stream(chunks)
309 }))?;
310
311 req.headers_mut().insert(CONTENT_LENGTH, content_length.into());
315
316 req
317 } else {
318 reqwest::Request::try_from(request)?
319 };
320
321 *request.timeout_mut() = timeout;
322 request
323 };
324
325 let response = client.execute(request).await?;
326 Ok(response_to_http_response(response).await?)
327}
328
/// Iterator that splits a [`Bytes`] buffer into consecutive chunks of at most `size` bytes.
struct BytesChunks {
    /// The bytes not yet yielded.
    bytes: Bytes,
    /// Maximum length of each yielded chunk; never zero (enforced by `BytesChunks::new`).
    size: usize,
}
333
334impl BytesChunks {
335 fn new(bytes: Bytes, size: usize) -> Self {
336 assert_ne!(size, 0);
337 Self { bytes, size }
338 }
339}
340
341impl Iterator for BytesChunks {
342 type Item = Bytes;
343
344 fn next(&mut self) -> Option<Self::Item> {
345 if self.bytes.is_empty() {
346 None
347 } else if self.bytes.len() < self.size {
348 Some(mem::take(&mut self.bytes))
349 } else {
350 Some(self.bytes.split_to(self.size))
351 }
352 }
353}
354
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::BytesChunks;

    /// Asserts that splitting `input` into chunks of `size` bytes yields exactly `expected`.
    #[track_caller]
    fn assert_chunks(input: &[u8], size: usize, expected: &[&[u8]]) {
        let actual: Vec<Bytes> = BytesChunks::new(Bytes::copy_from_slice(input), size).collect();
        let expected: Vec<Bytes> =
            expected.iter().map(|chunk| Bytes::copy_from_slice(chunk)).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_bytes_chunks() {
        // Empty input yields nothing.
        assert_chunks(&[], 1, &[]);

        // Input exactly one chunk long, and input shorter than a chunk.
        assert_chunks(&[1, 2], 2, &[&[1, 2]]);
        assert_chunks(&[1, 2], 3, &[&[1, 2]]);

        // Several chunks, with and without a shorter trailing chunk.
        assert_chunks(&[1, 2, 3], 1, &[&[1], &[2], &[3]]);
        assert_chunks(&[1, 2, 3], 2, &[&[1, 2], &[3]]);
        assert_chunks(&[1, 2, 3, 4], 2, &[&[1, 2], &[3, 4]]);
    }
}