matrix_sdk/http_client/native.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    fmt::Debug,
17    mem,
18    sync::atomic::{AtomicU64, Ordering},
19    time::Duration,
20};
21
22use backon::{ExponentialBuilder, Retryable};
23use bytes::Bytes;
24use bytesize::ByteSize;
25use eyeball::SharedObservable;
26use http::header::CONTENT_LENGTH;
27use reqwest::{tls, Certificate};
28use ruma::api::{error::FromHttpResponseError, IncomingResponse, OutgoingRequest};
29use tracing::{debug, info, warn};
30
31use super::{response_to_http_response, HttpClient, TransmissionProgress, DEFAULT_REQUEST_TIMEOUT};
32use crate::{
33    config::RequestConfig,
34    error::{HttpError, RetryKind},
35};
36
37impl HttpClient {
38    pub(super) async fn send_request<R>(
39        &self,
40        request: http::Request<Bytes>,
41        config: RequestConfig,
42        send_progress: SharedObservable<TransmissionProgress>,
43    ) -> Result<R::IncomingResponse, HttpError>
44    where
45        R: OutgoingRequest + Debug,
46        HttpError: From<FromHttpResponseError<R::EndpointError>>,
47    {
48        // These values were picked because we used to use the `backoff` crate, those
49        // were defined here: https://docs.rs/backoff/0.4.0/backoff/default/index.html
50        let backoff = ExponentialBuilder::new()
51            .with_min_delay(Duration::from_millis(500))
52            .with_max_delay(Duration::from_secs(60))
53            .with_total_delay(Some(Duration::from_secs(15 * 60)))
54            .without_max_times();
55
56        // Let's now apply any override the user or the SDK might have set.
57        let backoff = if let Some(max_delay) = config.max_retry_time {
58            backoff.with_max_delay(max_delay)
59        } else {
60            backoff
61        };
62
63        let backoff = if let Some(max_times) = config.retry_limit {
64            // Backon behaves a bit differently to our own handcrafted max retry logic.
65            // We were counting from one while `backon` counts from zero.
66            backoff.with_max_times(max_times.saturating_sub(1))
67        } else {
68            backoff
69        };
70
71        let retry_count = AtomicU64::new(1);
72
73        let send_request = || {
74            let send_progress = send_progress.clone();
75
76            async {
77                let num_attempt = retry_count.fetch_add(1, Ordering::SeqCst);
78                debug!(num_attempt, "Sending request");
79
80                let response =
81                    send_request(&self.inner, &request, config.timeout, send_progress).await?;
82
83                let status_code = response.status();
84                let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
85                tracing::Span::current()
86                    .record("status", status_code.as_u16())
87                    .record("response_size", response_size.display().si_short().to_string());
88
89                // Record interesting headers. If you add more headers, ensure they're not
90                // confidential.
91                for (header_name, header_value) in response.headers() {
92                    let header_name = header_name.as_str().to_lowercase();
93
94                    // Header added in case of OAuth 2.0 authentication failure, so we can correlate
95                    // failures with a Sentry event emitted by the OAuth 2.0 authentication server.
96                    if header_name == "x-sentry-event-id" {
97                        tracing::Span::current()
98                            .record("sentry_event_id", header_value.to_str().unwrap_or("<???>"));
99                    }
100                }
101
102                R::IncomingResponse::try_from_http_response(response).map_err(HttpError::from)
103            }
104        };
105
106        let has_retry_limit = config.retry_limit.is_some();
107
108        send_request
109            .retry(backoff)
110            .adjust(|err, default_timeout| {
111                match err.retry_kind() {
112                    RetryKind::Transient { retry_after } => {
113                        // This bit is somewhat tricky but it's necessary so we respect the
114                        // `max_times` limit from `backon`.
115                        //
116                        // The exponential backoff in `backon` is implemented as an iterator that
117                        // returns `None` when we hit the `max_times` limit. So it's necessary to
118                        // only override the `default_timeout` if it's `Some`.
119                        if default_timeout.is_some() {
120                            retry_after.or(default_timeout)
121                        } else {
122                            None
123                        }
124                    }
125                    RetryKind::Permanent => None,
126                    RetryKind::NetworkFailure => {
127                        // If we ran into a network failure, only retry if there's some retry limit
128                        // associated to this request's configuration; otherwise, we would end up
129                        // running an infinite loop of network requests in offline mode.
130                        if has_retry_limit {
131                            default_timeout
132                        } else {
133                            None
134                        }
135                    }
136                }
137            })
138            .await
139    }
140}
141
142#[cfg(not(target_family = "wasm"))]
143#[derive(Clone, Debug)]
144pub(crate) struct HttpSettings {
145    pub(crate) disable_ssl_verification: bool,
146    pub(crate) proxy: Option<String>,
147    pub(crate) user_agent: Option<String>,
148    pub(crate) timeout: Duration,
149    pub(crate) additional_root_certificates: Vec<Certificate>,
150    pub(crate) disable_built_in_root_certificates: bool,
151}
152
153#[cfg(not(target_family = "wasm"))]
154impl Default for HttpSettings {
155    fn default() -> Self {
156        Self {
157            disable_ssl_verification: false,
158            proxy: None,
159            user_agent: None,
160            timeout: DEFAULT_REQUEST_TIMEOUT,
161            additional_root_certificates: Default::default(),
162            disable_built_in_root_certificates: false,
163        }
164    }
165}
166
167#[cfg(not(target_family = "wasm"))]
168impl HttpSettings {
169    /// Build a client with the specified configuration.
170    pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
171        let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
172        let mut http_client = reqwest::Client::builder()
173            .user_agent(user_agent)
174            .timeout(self.timeout)
175            // As recommended by BCP 195.
176            // See: https://datatracker.ietf.org/doc/bcp195/
177            .min_tls_version(tls::Version::TLS_1_2);
178
179        if self.disable_ssl_verification {
180            warn!("SSL verification disabled in the HTTP client!");
181            http_client = http_client.danger_accept_invalid_certs(true)
182        }
183
184        if !self.additional_root_certificates.is_empty() {
185            info!(
186                "Adding {} additional root certificates to the HTTP client",
187                self.additional_root_certificates.len()
188            );
189
190            for cert in &self.additional_root_certificates {
191                http_client = http_client.add_root_certificate(cert.clone());
192            }
193        }
194
195        if self.disable_built_in_root_certificates {
196            info!("Built-in root certificates disabled in the HTTP client.");
197            http_client = http_client.tls_built_in_root_certs(false);
198        }
199
200        if let Some(p) = &self.proxy {
201            info!(proxy_url = p, "Setting the proxy for the HTTP client");
202            http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
203        }
204
205        Ok(http_client.build()?)
206    }
207}
208
209pub(super) async fn send_request(
210    client: &reqwest::Client,
211    request: &http::Request<Bytes>,
212    timeout: Duration,
213    send_progress: SharedObservable<TransmissionProgress>,
214) -> Result<http::Response<Bytes>, HttpError> {
215    use std::convert::Infallible;
216
217    use futures_util::stream;
218
219    let request = request.clone();
220    let request = {
221        let mut request = if send_progress.subscriber_count() != 0 {
222            let content_length = request.body().len();
223            send_progress.update(|p| p.total += content_length);
224
225            // Make sure any concurrent futures in the same task get a chance
226            // to also add to the progress total before the first chunks are
227            // pulled out of the body stream.
228            tokio::task::yield_now().await;
229
230            let mut req = reqwest::Request::try_from(request.map(|body| {
231                let chunks = stream::iter(BytesChunks::new(body, 8192).map(
232                    move |chunk| -> Result<_, Infallible> {
233                        send_progress.update(|p| p.current += chunk.len());
234                        Ok(chunk)
235                    },
236                ));
237                reqwest::Body::wrap_stream(chunks)
238            }))?;
239
240            // When streaming the request, reqwest / hyper doesn't know how
241            // large the body is, so it doesn't set the content-length header
242            // (required by some servers). Set it manually.
243            req.headers_mut().insert(CONTENT_LENGTH, content_length.into());
244
245            req
246        } else {
247            reqwest::Request::try_from(request)?
248        };
249
250        *request.timeout_mut() = Some(timeout);
251        request
252    };
253
254    let response = client.execute(request).await?;
255    Ok(response_to_http_response(response).await?)
256}
257
258struct BytesChunks {
259    bytes: Bytes,
260    size: usize,
261}
262
263impl BytesChunks {
264    fn new(bytes: Bytes, size: usize) -> Self {
265        assert_ne!(size, 0);
266        Self { bytes, size }
267    }
268}
269
270impl Iterator for BytesChunks {
271    type Item = Bytes;
272
273    fn next(&mut self) -> Option<Self::Item> {
274        if self.bytes.is_empty() {
275            None
276        } else if self.bytes.len() < self.size {
277            Some(mem::take(&mut self.bytes))
278        } else {
279            Some(self.bytes.split_to(self.size))
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::BytesChunks;

    /// Collect the chunks [`BytesChunks`] yields for `input` split by `size`.
    fn chunks_of<const N: usize>(input: [u8; N], size: usize) -> Vec<Bytes> {
        BytesChunks::new(Bytes::from_iter(input), size).collect()
    }

    #[test]
    fn test_bytes_chunks() {
        // An empty buffer yields nothing at all.
        assert!(BytesChunks::new(Bytes::new(), 1).collect::<Vec<_>>().is_empty());

        // A buffer no longer than the chunk size comes back in one piece.
        assert_eq!(chunks_of([1, 2], 2), [Bytes::from_iter([1, 2])]);
        assert_eq!(chunks_of([1, 2], 3), [Bytes::from_iter([1, 2])]);

        // Longer buffers are split, with a shorter trailing chunk when needed.
        assert_eq!(
            chunks_of([1, 2, 3], 1),
            [Bytes::from_iter([1]), Bytes::from_iter([2]), Bytes::from_iter([3])]
        );
        assert_eq!(chunks_of([1, 2, 3], 2), [Bytes::from_iter([1, 2]), Bytes::from_iter([3])]);
        assert_eq!(
            chunks_of([1, 2, 3, 4], 2),
            [Bytes::from_iter([1, 2]), Bytes::from_iter([3, 4])]
        );
    }
}