matrix_sdk/http_client/native.rs

// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::{
16    fmt::Debug,
17    mem,
18    sync::atomic::{AtomicU64, Ordering},
19    time::Duration,
20};
21
22use backon::{ExponentialBuilder, Retryable};
23use bytes::Bytes;
24use bytesize::ByteSize;
25use eyeball::SharedObservable;
26use http::header::CONTENT_LENGTH;
27use reqwest::{Certificate, tls};
28use ruma::api::{IncomingResponse, OutgoingRequest, error::FromHttpResponseError};
29use tracing::{debug, info, warn};
30
31use super::{DEFAULT_REQUEST_TIMEOUT, HttpClient, TransmissionProgress, response_to_http_response};
32use crate::{
33    config::RequestConfig,
34    error::{HttpError, RetryKind},
35};
36
37impl HttpClient {
38    pub(super) async fn send_request<R>(
39        &self,
40        request: http::Request<Bytes>,
41        config: RequestConfig,
42        send_progress: SharedObservable<TransmissionProgress>,
43    ) -> Result<R::IncomingResponse, HttpError>
44    where
45        R: OutgoingRequest + Debug,
46        HttpError: From<FromHttpResponseError<R::EndpointError>>,
47    {
48        // These values were picked because we used to use the `backoff` crate, those
49        // were defined here: https://docs.rs/backoff/0.4.0/backoff/default/index.html
50        let backoff = ExponentialBuilder::new()
51            .with_min_delay(Duration::from_millis(500))
52            .with_max_delay(Duration::from_secs(60))
53            .with_total_delay(Some(Duration::from_secs(15 * 60)))
54            .without_max_times();
55
56        // Let's now apply any override the user or the SDK might have set.
57        let backoff = if let Some(max_delay) = config.max_retry_time {
58            backoff.with_max_delay(max_delay)
59        } else {
60            backoff
61        };
62
63        let backoff = if let Some(max_times) = config.retry_limit {
64            // Backon behaves a bit differently to our own handcrafted max retry logic.
65            // We were counting from one while `backon` counts from zero.
66            backoff.with_max_times(max_times.saturating_sub(1))
67        } else {
68            backoff
69        };
70
71        let retry_count = AtomicU64::new(1);
72
73        let send_request = || {
74            let send_progress = send_progress.clone();
75
76            async {
77                let num_attempt = retry_count.fetch_add(1, Ordering::SeqCst);
78                debug!(num_attempt, "Sending request");
79                let before = ruma::time::Instant::now();
80
81                let response =
82                    send_request(&self.inner, &request, config.timeout, send_progress).await?;
83
84                let request_duration = ruma::time::Instant::now().saturating_duration_since(before);
85
86                let status_code = response.status();
87                let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
88                tracing::Span::current()
89                    .record("status", status_code.as_u16())
90                    .record("response_size", response_size.display().si_short().to_string())
91                    .record("request_duration", tracing::field::debug(request_duration));
92
93                // Record interesting headers. If you add more headers, ensure they're not
94                // confidential.
95                for (header_name, header_value) in response.headers() {
96                    let header_name = header_name.as_str().to_lowercase();
97
98                    // Header added in case of OAuth 2.0 authentication failure, so we can correlate
99                    // failures with a Sentry event emitted by the OAuth 2.0 authentication server.
100                    if header_name == "x-sentry-event-id" {
101                        tracing::Span::current()
102                            .record("sentry_event_id", header_value.to_str().unwrap_or("<???>"));
103                    }
104                }
105
106                R::IncomingResponse::try_from_http_response(response).map_err(HttpError::from)
107            }
108        };
109
110        let has_retry_limit = config.retry_limit.is_some();
111
112        send_request
113            .retry(backoff)
114            .adjust(|err, default_timeout| {
115                match err.retry_kind() {
116                    RetryKind::Transient { retry_after } => {
117                        // This bit is somewhat tricky but it's necessary so we respect the
118                        // `max_times` limit from `backon`.
119                        //
120                        // The exponential backoff in `backon` is implemented as an iterator that
121                        // returns `None` when we hit the `max_times` limit. So it's necessary to
122                        // only override the `default_timeout` if it's `Some`.
123                        if default_timeout.is_some() {
124                            retry_after.or(default_timeout)
125                        } else {
126                            None
127                        }
128                    }
129                    RetryKind::Permanent => None,
130                    RetryKind::NetworkFailure => {
131                        // If we ran into a network failure, only retry if there's some retry limit
132                        // associated to this request's configuration; otherwise, we would end up
133                        // running an infinite loop of network requests in offline mode.
134                        if has_retry_limit { default_timeout } else { None }
135                    }
136                }
137            })
138            .await
139    }
140}
141
142#[cfg(not(target_family = "wasm"))]
143#[derive(Clone, Debug)]
144pub(crate) struct HttpSettings {
145    pub(crate) disable_ssl_verification: bool,
146    pub(crate) proxy: Option<String>,
147    pub(crate) user_agent: Option<String>,
148    pub(crate) timeout: Option<Duration>,
149    pub(crate) read_timeout: Option<Duration>,
150    pub(crate) additional_root_certificates: Vec<Certificate>,
151    pub(crate) disable_built_in_root_certificates: bool,
152}
153
154#[cfg(not(target_family = "wasm"))]
155impl Default for HttpSettings {
156    fn default() -> Self {
157        Self {
158            disable_ssl_verification: false,
159            proxy: None,
160            user_agent: None,
161            timeout: Some(DEFAULT_REQUEST_TIMEOUT),
162            read_timeout: None,
163            additional_root_certificates: Default::default(),
164            disable_built_in_root_certificates: false,
165        }
166    }
167}
168
169#[cfg(not(target_family = "wasm"))]
170impl HttpSettings {
171    /// Build a client with the specified configuration.
172    pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
173        let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
174        let mut http_client = reqwest::Client::builder()
175            .user_agent(user_agent)
176            // As recommended by BCP 195.
177            // See: https://datatracker.ietf.org/doc/bcp195/
178            .min_tls_version(tls::Version::TLS_1_2);
179
180        if let Some(timeout) = self.timeout {
181            http_client = http_client.timeout(timeout);
182        }
183
184        if let Some(read_timeout) = self.read_timeout {
185            http_client = http_client.read_timeout(read_timeout);
186        }
187
188        if self.disable_ssl_verification {
189            warn!("SSL verification disabled in the HTTP client!");
190            http_client = http_client.danger_accept_invalid_certs(true)
191        }
192
193        if !self.additional_root_certificates.is_empty() {
194            info!(
195                "Adding {} additional root certificates to the HTTP client",
196                self.additional_root_certificates.len()
197            );
198
199            for cert in &self.additional_root_certificates {
200                http_client = http_client.add_root_certificate(cert.clone());
201            }
202        }
203
204        if self.disable_built_in_root_certificates {
205            info!("Built-in root certificates disabled in the HTTP client.");
206            http_client = http_client.tls_built_in_root_certs(false);
207        }
208
209        if let Some(p) = &self.proxy {
210            info!(proxy_url = p, "Setting the proxy for the HTTP client");
211            http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
212        }
213
214        Ok(http_client.build()?)
215    }
216}
217
218pub(super) async fn send_request(
219    client: &reqwest::Client,
220    request: &http::Request<Bytes>,
221    timeout: Option<Duration>,
222    send_progress: SharedObservable<TransmissionProgress>,
223) -> Result<http::Response<Bytes>, HttpError> {
224    use std::convert::Infallible;
225
226    use futures_util::stream;
227
228    let request = request.clone();
229    let request = {
230        let mut request = if send_progress.subscriber_count() != 0 {
231            let content_length = request.body().len();
232            send_progress.update(|p| p.total += content_length);
233
234            // Make sure any concurrent futures in the same task get a chance
235            // to also add to the progress total before the first chunks are
236            // pulled out of the body stream.
237            tokio::task::yield_now().await;
238
239            let mut req = reqwest::Request::try_from(request.map(|body| {
240                let chunks = stream::iter(BytesChunks::new(body, 8192).map(
241                    move |chunk| -> Result<_, Infallible> {
242                        send_progress.update(|p| p.current += chunk.len());
243                        Ok(chunk)
244                    },
245                ));
246                reqwest::Body::wrap_stream(chunks)
247            }))?;
248
249            // When streaming the request, reqwest / hyper doesn't know how
250            // large the body is, so it doesn't set the content-length header
251            // (required by some servers). Set it manually.
252            req.headers_mut().insert(CONTENT_LENGTH, content_length.into());
253
254            req
255        } else {
256            reqwest::Request::try_from(request)?
257        };
258
259        *request.timeout_mut() = timeout;
260        request
261    };
262
263    let response = client.execute(request).await?;
264    Ok(response_to_http_response(response).await?)
265}
266
267struct BytesChunks {
268    bytes: Bytes,
269    size: usize,
270}
271
272impl BytesChunks {
273    fn new(bytes: Bytes, size: usize) -> Self {
274        assert_ne!(size, 0);
275        Self { bytes, size }
276    }
277}
278
279impl Iterator for BytesChunks {
280    type Item = Bytes;
281
282    fn next(&mut self) -> Option<Self::Item> {
283        if self.bytes.is_empty() {
284            None
285        } else if self.bytes.len() < self.size {
286            Some(mem::take(&mut self.bytes))
287        } else {
288            Some(self.bytes.split_to(self.size))
289        }
290    }
291}
292
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::BytesChunks;

    /// Run `BytesChunks` over `input` with the given chunk `size` and gather
    /// every chunk it yields.
    fn chunks_of(input: &'static [u8], size: usize) -> Vec<Bytes> {
        BytesChunks::new(Bytes::from_static(input), size).collect()
    }

    #[test]
    fn test_bytes_chunks() {
        // An empty buffer yields nothing.
        assert!(chunks_of(&[], 1).is_empty());

        // A buffer no longer than the chunk size comes back as a single chunk.
        assert_eq!(chunks_of(&[1, 2], 2), [Bytes::from_static(&[1, 2])]);
        assert_eq!(chunks_of(&[1, 2], 3), [Bytes::from_static(&[1, 2])]);

        // Longer buffers are split; the last chunk is shorter when the length
        // is not a multiple of the chunk size.
        assert_eq!(
            chunks_of(&[1, 2, 3], 1),
            [Bytes::from_static(&[1]), Bytes::from_static(&[2]), Bytes::from_static(&[3])]
        );
        assert_eq!(
            chunks_of(&[1, 2, 3], 2),
            [Bytes::from_static(&[1, 2]), Bytes::from_static(&[3])]
        );
        assert_eq!(
            chunks_of(&[1, 2, 3, 4], 2),
            [Bytes::from_static(&[1, 2]), Bytes::from_static(&[3, 4])]
        );
    }
}
328}