1use std::{
16 fmt::Debug,
17 mem,
18 sync::atomic::{AtomicU64, Ordering},
19 time::Duration,
20};
21
22use backon::{ExponentialBuilder, Retryable};
23use bytes::Bytes;
24use bytesize::ByteSize;
25use eyeball::SharedObservable;
26use http::header::CONTENT_LENGTH;
27use reqwest::{Certificate, tls};
28use ruma::api::{IncomingResponse, OutgoingRequest, error::FromHttpResponseError};
29use tracing::{debug, info, warn};
30
31use super::{DEFAULT_REQUEST_TIMEOUT, HttpClient, TransmissionProgress, response_to_http_response};
32use crate::{
33 config::RequestConfig,
34 error::{HttpError, RetryKind},
35};
36
37impl HttpClient {
38 pub(super) async fn send_request<R>(
39 &self,
40 request: http::Request<Bytes>,
41 config: RequestConfig,
42 send_progress: SharedObservable<TransmissionProgress>,
43 ) -> Result<R::IncomingResponse, HttpError>
44 where
45 R: OutgoingRequest + Debug,
46 HttpError: From<FromHttpResponseError<R::EndpointError>>,
47 {
48 let backoff = ExponentialBuilder::new()
51 .with_min_delay(Duration::from_millis(500))
52 .with_max_delay(Duration::from_secs(60))
53 .with_total_delay(Some(Duration::from_secs(15 * 60)))
54 .without_max_times();
55
56 let backoff = if let Some(max_delay) = config.max_retry_time {
58 backoff.with_max_delay(max_delay)
59 } else {
60 backoff
61 };
62
63 let backoff = if let Some(max_times) = config.retry_limit {
64 backoff.with_max_times(max_times.saturating_sub(1))
67 } else {
68 backoff
69 };
70
71 let retry_count = AtomicU64::new(1);
72
73 let send_request = || {
74 let send_progress = send_progress.clone();
75
76 async {
77 let num_attempt = retry_count.fetch_add(1, Ordering::SeqCst);
78 debug!(num_attempt, "Sending request");
79 let before = ruma::time::Instant::now();
80
81 let response =
82 send_request(&self.inner, &request, config.timeout, send_progress).await?;
83
84 let request_duration = ruma::time::Instant::now().saturating_duration_since(before);
85
86 let status_code = response.status();
87 let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
88 tracing::Span::current()
89 .record("status", status_code.as_u16())
90 .record("response_size", response_size.display().si_short().to_string())
91 .record("request_duration", tracing::field::debug(request_duration));
92
93 for (header_name, header_value) in response.headers() {
96 let header_name = header_name.as_str().to_lowercase();
97
98 if header_name == "x-sentry-event-id" {
101 tracing::Span::current()
102 .record("sentry_event_id", header_value.to_str().unwrap_or("<???>"));
103 }
104 }
105
106 R::IncomingResponse::try_from_http_response(response).map_err(HttpError::from)
107 }
108 };
109
110 let has_retry_limit = config.retry_limit.is_some();
111
112 send_request
113 .retry(backoff)
114 .adjust(|err, default_timeout| {
115 match err.retry_kind() {
116 RetryKind::Transient { retry_after } => {
117 if default_timeout.is_some() {
124 retry_after.or(default_timeout)
125 } else {
126 None
127 }
128 }
129 RetryKind::Permanent => None,
130 RetryKind::NetworkFailure => {
131 if has_retry_limit { default_timeout } else { None }
135 }
136 }
137 })
138 .await
139 }
140}
141
142#[cfg(not(target_family = "wasm"))]
143#[derive(Clone, Debug)]
144pub(crate) struct HttpSettings {
145 pub(crate) disable_ssl_verification: bool,
146 pub(crate) proxy: Option<String>,
147 pub(crate) user_agent: Option<String>,
148 pub(crate) timeout: Option<Duration>,
149 pub(crate) read_timeout: Option<Duration>,
150 pub(crate) additional_root_certificates: Vec<Certificate>,
151 pub(crate) disable_built_in_root_certificates: bool,
152}
153
154#[cfg(not(target_family = "wasm"))]
155impl Default for HttpSettings {
156 fn default() -> Self {
157 Self {
158 disable_ssl_verification: false,
159 proxy: None,
160 user_agent: None,
161 timeout: Some(DEFAULT_REQUEST_TIMEOUT),
162 read_timeout: None,
163 additional_root_certificates: Default::default(),
164 disable_built_in_root_certificates: false,
165 }
166 }
167}
168
169#[cfg(not(target_family = "wasm"))]
170impl HttpSettings {
171 pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
173 let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
174 let mut http_client = reqwest::Client::builder()
175 .user_agent(user_agent)
176 .min_tls_version(tls::Version::TLS_1_2);
179
180 if let Some(timeout) = self.timeout {
181 http_client = http_client.timeout(timeout);
182 }
183
184 if let Some(read_timeout) = self.read_timeout {
185 http_client = http_client.read_timeout(read_timeout);
186 }
187
188 if self.disable_ssl_verification {
189 warn!("SSL verification disabled in the HTTP client!");
190 http_client = http_client.danger_accept_invalid_certs(true)
191 }
192
193 if !self.additional_root_certificates.is_empty() {
194 info!(
195 "Adding {} additional root certificates to the HTTP client",
196 self.additional_root_certificates.len()
197 );
198
199 for cert in &self.additional_root_certificates {
200 http_client = http_client.add_root_certificate(cert.clone());
201 }
202 }
203
204 if self.disable_built_in_root_certificates {
205 info!("Built-in root certificates disabled in the HTTP client.");
206 http_client = http_client.tls_built_in_root_certs(false);
207 }
208
209 if let Some(p) = &self.proxy {
210 info!(proxy_url = p, "Setting the proxy for the HTTP client");
211 http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
212 }
213
214 Ok(http_client.build()?)
215 }
216}
217
218pub(super) async fn send_request(
219 client: &reqwest::Client,
220 request: &http::Request<Bytes>,
221 timeout: Option<Duration>,
222 send_progress: SharedObservable<TransmissionProgress>,
223) -> Result<http::Response<Bytes>, HttpError> {
224 use std::convert::Infallible;
225
226 use futures_util::stream;
227
228 let request = request.clone();
229 let request = {
230 let mut request = if send_progress.subscriber_count() != 0 {
231 let content_length = request.body().len();
232 send_progress.update(|p| p.total += content_length);
233
234 tokio::task::yield_now().await;
238
239 let mut req = reqwest::Request::try_from(request.map(|body| {
240 let chunks = stream::iter(BytesChunks::new(body, 8192).map(
241 move |chunk| -> Result<_, Infallible> {
242 send_progress.update(|p| p.current += chunk.len());
243 Ok(chunk)
244 },
245 ));
246 reqwest::Body::wrap_stream(chunks)
247 }))?;
248
249 req.headers_mut().insert(CONTENT_LENGTH, content_length.into());
253
254 req
255 } else {
256 reqwest::Request::try_from(request)?
257 };
258
259 *request.timeout_mut() = timeout;
260 request
261 };
262
263 let response = client.execute(request).await?;
264 Ok(response_to_http_response(response).await?)
265}
266
267struct BytesChunks {
268 bytes: Bytes,
269 size: usize,
270}
271
272impl BytesChunks {
273 fn new(bytes: Bytes, size: usize) -> Self {
274 assert_ne!(size, 0);
275 Self { bytes, size }
276 }
277}
278
279impl Iterator for BytesChunks {
280 type Item = Bytes;
281
282 fn next(&mut self) -> Option<Self::Item> {
283 if self.bytes.is_empty() {
284 None
285 } else if self.bytes.len() < self.size {
286 Some(mem::take(&mut self.bytes))
287 } else {
288 Some(self.bytes.split_to(self.size))
289 }
290 }
291}
292
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::BytesChunks;

    /// Chunks a copy of `input` into pieces of `size` bytes and collects them.
    fn chunks_of(input: &[u8], size: usize) -> Vec<Bytes> {
        BytesChunks::new(Bytes::copy_from_slice(input), size).collect()
    }

    #[test]
    fn test_bytes_chunks() {
        // Empty input yields no chunk at all.
        assert!(chunks_of(&[], 1).is_empty());

        // Input no longer than the chunk size comes back as a single chunk.
        assert_eq!(chunks_of(&[1, 2], 2), [Bytes::from_static(&[1, 2])]);
        assert_eq!(chunks_of(&[1, 2], 3), [Bytes::from_static(&[1, 2])]);

        // Longer input is split into full chunks plus a possibly shorter tail.
        assert_eq!(
            chunks_of(&[1, 2, 3], 1),
            [Bytes::from_static(&[1]), Bytes::from_static(&[2]), Bytes::from_static(&[3])]
        );
        assert_eq!(
            chunks_of(&[1, 2, 3], 2),
            [Bytes::from_static(&[1, 2]), Bytes::from_static(&[3])]
        );
        assert_eq!(
            chunks_of(&[1, 2, 3, 4], 2),
            [Bytes::from_static(&[1, 2]), Bytes::from_static(&[3, 4])]
        );
    }
}