Skip to main content

matrix_sdk/http_client/
native.rs

1// Copyright 2023 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[cfg(target_os = "android")]
16use std::sync::Arc;
17use std::{
18    fmt::Debug,
19    mem,
20    sync::atomic::{AtomicU64, Ordering},
21    time::Duration,
22};
23
24use backon::{ExponentialBuilder, Retryable};
25use bytes::Bytes;
26use bytesize::ByteSize;
27use eyeball::SharedObservable;
28use http::header::CONTENT_LENGTH;
29#[cfg(not(target_family = "wasm"))]
30use reqwest::Certificate;
31#[cfg(target_os = "android")]
32use reqwest::ClientBuilder;
33use reqwest::tls;
34use ruma::api::{IncomingResponse, OutgoingRequest, error::FromHttpResponseError};
35#[cfg(target_os = "android")]
36use rustls::{RootCertStore, client::WebPkiServerVerifier};
37#[cfg(target_os = "android")]
38use rustls_pki_types::CertificateDer;
39use tracing::{debug, info, warn};
40
41use super::{DEFAULT_REQUEST_TIMEOUT, HttpClient, TransmissionProgress, response_to_http_response};
42use crate::{
43    config::RequestConfig,
44    error::{HttpError, RetryKind},
45};
46
47impl HttpClient {
48    pub(super) async fn send_request<R>(
49        &self,
50        request: http::Request<Bytes>,
51        config: RequestConfig,
52        send_progress: SharedObservable<TransmissionProgress>,
53    ) -> Result<R::IncomingResponse, HttpError>
54    where
55        R: OutgoingRequest + Debug,
56        HttpError: From<FromHttpResponseError<R::EndpointError>>,
57    {
58        // These values were picked because we used to use the `backoff` crate, those
59        // were defined here: https://docs.rs/backoff/0.4.0/backoff/default/index.html
60        let backoff = ExponentialBuilder::new()
61            .with_min_delay(Duration::from_millis(500))
62            .with_max_delay(Duration::from_secs(60))
63            .with_total_delay(Some(Duration::from_secs(15 * 60)))
64            .without_max_times();
65
66        // Let's now apply any override the user or the SDK might have set.
67        let backoff = if let Some(max_delay) = config.max_retry_time {
68            backoff.with_max_delay(max_delay)
69        } else {
70            backoff
71        };
72
73        let backoff = if let Some(max_times) = config.retry_limit {
74            // Backon behaves a bit differently to our own handcrafted max retry logic.
75            // We were counting from one while `backon` counts from zero.
76            backoff.with_max_times(max_times.saturating_sub(1))
77        } else {
78            backoff
79        };
80
81        let retry_count = AtomicU64::new(1);
82
83        let send_request = || {
84            let send_progress = send_progress.clone();
85
86            async {
87                let num_attempt = retry_count.fetch_add(1, Ordering::SeqCst);
88                debug!(num_attempt, "Sending request");
89                let before = ruma::time::Instant::now();
90
91                let response =
92                    send_request(&self.inner, &request, config.timeout, send_progress).await?;
93
94                let request_duration = ruma::time::Instant::now().saturating_duration_since(before);
95
96                let status_code = response.status();
97                let response_size = ByteSize(response.body().len().try_into().unwrap_or(u64::MAX));
98                tracing::Span::current()
99                    .record("status", status_code.as_u16())
100                    .record("response_size", response_size.display().si_short().to_string())
101                    .record("request_duration", tracing::field::debug(request_duration));
102
103                // Record interesting headers. If you add more headers, ensure they're not
104                // confidential.
105                for (header_name, header_value) in response.headers() {
106                    let header_name = header_name.as_str().to_lowercase();
107
108                    // Header added in case of OAuth 2.0 authentication failure, so we can correlate
109                    // failures with a Sentry event emitted by the OAuth 2.0 authentication server.
110                    if header_name == "x-sentry-event-id" {
111                        tracing::Span::current()
112                            .record("sentry_event_id", header_value.to_str().unwrap_or("<???>"));
113                    }
114                }
115
116                R::IncomingResponse::try_from_http_response(response).map_err(HttpError::from)
117            }
118        };
119
120        let has_retry_limit = config.retry_limit.is_some();
121
122        send_request
123            .retry(backoff)
124            .adjust(|err, backon_suggested_timeout| {
125                match err.retry_kind() {
126                    RetryKind::Transient { retry_after } => {
127                        // This bit is somewhat tricky but it's necessary so we respect the
128                        // `max_times` limit from `backon`.
129                        //
130                        // The exponential backoff in `backon` is implemented as an iterator that
131                        // returns `None` when we hit the `max_times` limit; if it returned `None`,
132                        // that means we ran out of attempts. So it's necessary to only override
133                        // the `backon_suggested_timeout` if it's `Some`.
134                        if backon_suggested_timeout.is_some() {
135                            retry_after.or(backon_suggested_timeout)
136                        } else {
137                            None
138                        }
139                    }
140                    RetryKind::Permanent => None,
141                    RetryKind::NetworkFailure => {
142                        // If we ran into a network failure, only retry if there's some retry limit
143                        // associated to this request's configuration; otherwise, we would end up
144                        // running an infinite loop of network requests in offline mode.
145                        if has_retry_limit { backon_suggested_timeout } else { None }
146                    }
147                }
148            })
149            .await
150    }
151}
152
153#[cfg(not(target_family = "wasm"))]
154#[derive(Clone, Debug)]
155pub(crate) struct HttpSettings {
156    pub(crate) disable_ssl_verification: bool,
157    pub(crate) proxy: Option<String>,
158    pub(crate) user_agent: Option<String>,
159    pub(crate) timeout: Option<Duration>,
160    pub(crate) read_timeout: Option<Duration>,
161    pub(crate) additional_root_certificates: Vec<Certificate>,
162    #[cfg(target_os = "android")]
163    pub(crate) additional_raw_root_certificates: Vec<Vec<u8>>,
164    pub(crate) disable_built_in_root_certificates: bool,
165}
166
167#[cfg(not(target_family = "wasm"))]
168impl Default for HttpSettings {
169    fn default() -> Self {
170        Self {
171            disable_ssl_verification: false,
172            proxy: None,
173            user_agent: None,
174            timeout: Some(DEFAULT_REQUEST_TIMEOUT),
175            read_timeout: None,
176            additional_root_certificates: Default::default(),
177            #[cfg(target_os = "android")]
178            additional_raw_root_certificates: Default::default(),
179            disable_built_in_root_certificates: false,
180        }
181    }
182}
183
184#[cfg(not(target_family = "wasm"))]
185impl HttpSettings {
186    /// Build a client with the specified configuration.
187    pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
188        let user_agent = self.user_agent.clone().unwrap_or_else(|| "matrix-rust-sdk".to_owned());
189        let mut http_client = reqwest::Client::builder()
190            .user_agent(user_agent)
191            // As recommended by BCP 195.
192            // See: https://datatracker.ietf.org/doc/bcp195/
193            .min_tls_version(tls::Version::TLS_1_2);
194
195        if let Some(timeout) = self.timeout {
196            http_client = http_client.timeout(timeout);
197        }
198
199        if let Some(read_timeout) = self.read_timeout {
200            http_client = http_client.read_timeout(read_timeout);
201        }
202
203        // On Android there is a problem that causes some certificates to be incorrectly
204        // marked as revoked, so we build our own rustls instance with the right
205        // configuration.
206        // Remove when https://github.com/rustls/rustls-platform-verifier/issues/221 is fixed.
207        #[cfg(target_os = "android")]
208        {
209            http_client = self.android_setup_webkpi_verifier(http_client)?;
210        }
211
212        if self.disable_ssl_verification {
213            warn!("SSL verification disabled in the HTTP client!");
214            http_client = http_client.danger_accept_invalid_certs(true);
215        }
216
217        http_client = if self.disable_built_in_root_certificates {
218            info!("Built-in root certificates disabled in the HTTP client.");
219            http_client.tls_certs_only(self.additional_root_certificates.clone())
220        } else {
221            http_client.tls_certs_merge(self.additional_root_certificates.clone())
222        };
223
224        if let Some(p) = &self.proxy {
225            info!(proxy_url = p, "Setting the proxy for the HTTP client");
226            http_client = http_client.proxy(reqwest::Proxy::all(p.as_str())?);
227        }
228
229        Ok(http_client.build()?)
230    }
231
232    #[cfg(target_os = "android")]
233    fn android_setup_webkpi_verifier(
234        &self,
235        client_builder: ClientBuilder,
236    ) -> Result<ClientBuilder, HttpError> {
237        if !self.disable_ssl_verification {
238            let mut root_store = RootCertStore::empty();
239
240            if self.disable_built_in_root_certificates {
241                info!("Built-in root certificates disabled in the HTTP client.");
242            } else {
243                // This seems to fix the 'revoked certificate' false positives issue
244                root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
245
246                // Also load the native certs
247                let native_certs = rustls_native_certs::load_native_certs().certs;
248                root_store.add_parsable_certificates(native_certs);
249            }
250
251            if !self.additional_raw_root_certificates.is_empty() {
252                let mut additional_certs = Vec::new();
253
254                warn!(
255                    "Adding {} extra user certificates",
256                    self.additional_raw_root_certificates.len()
257                );
258
259                for certificate in self.additional_raw_root_certificates.iter() {
260                    additional_certs.push(CertificateDer::from_slice(certificate));
261                }
262
263                root_store.add_parsable_certificates(additional_certs);
264            }
265
266            let verifier = WebPkiServerVerifier::builder(Arc::new(root_store))
267                .build()
268                .map_err(HttpError::VerifierBuilder)?;
269
270            let config = rustls::ClientConfig::builder()
271                .with_webpki_verifier(verifier)
272                .with_no_client_auth();
273            Ok(client_builder.tls_backend_preconfigured(config))
274        } else {
275            Ok(client_builder)
276        }
277    }
278}
279
280pub(super) async fn send_request(
281    client: &reqwest::Client,
282    request: &http::Request<Bytes>,
283    timeout: Option<Duration>,
284    send_progress: SharedObservable<TransmissionProgress>,
285) -> Result<http::Response<Bytes>, HttpError> {
286    use std::convert::Infallible;
287
288    use futures_util::stream;
289
290    let request = request.clone();
291    let request = {
292        let mut request = if send_progress.subscriber_count() != 0 {
293            let content_length = request.body().len();
294            send_progress.update(|p| p.total += content_length);
295
296            // Make sure any concurrent futures in the same task get a chance
297            // to also add to the progress total before the first chunks are
298            // pulled out of the body stream.
299            tokio::task::yield_now().await;
300
301            let mut req = reqwest::Request::try_from(request.map(|body| {
302                let chunks = stream::iter(BytesChunks::new(body, 8192).map(
303                    move |chunk| -> Result<_, Infallible> {
304                        send_progress.update(|p| p.current += chunk.len());
305                        Ok(chunk)
306                    },
307                ));
308                reqwest::Body::wrap_stream(chunks)
309            }))?;
310
311            // When streaming the request, reqwest / hyper doesn't know how
312            // large the body is, so it doesn't set the content-length header
313            // (required by some servers). Set it manually.
314            req.headers_mut().insert(CONTENT_LENGTH, content_length.into());
315
316            req
317        } else {
318            reqwest::Request::try_from(request)?
319        };
320
321        *request.timeout_mut() = timeout;
322        request
323    };
324
325    let response = client.execute(request).await?;
326    Ok(response_to_http_response(response).await?)
327}
328
329struct BytesChunks {
330    bytes: Bytes,
331    size: usize,
332}
333
334impl BytesChunks {
335    fn new(bytes: Bytes, size: usize) -> Self {
336        assert_ne!(size, 0);
337        Self { bytes, size }
338    }
339}
340
341impl Iterator for BytesChunks {
342    type Item = Bytes;
343
344    fn next(&mut self) -> Option<Self::Item> {
345        if self.bytes.is_empty() {
346            None
347        } else if self.bytes.len() < self.size {
348            Some(mem::take(&mut self.bytes))
349        } else {
350            Some(self.bytes.split_to(self.size))
351        }
352    }
353}
354
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::BytesChunks;

    /// Splits `input` with [`BytesChunks`] and returns the chunks as plain vectors.
    fn chunked(input: &[u8], size: usize) -> Vec<Vec<u8>> {
        BytesChunks::new(Bytes::copy_from_slice(input), size).map(|chunk| chunk.to_vec()).collect()
    }

    #[test]
    fn test_bytes_chunks() {
        // (input, chunk size, expected chunks)
        let cases: &[(&[u8], usize, &[&[u8]])] = &[
            (&[], 1, &[]),
            (&[1, 2], 2, &[&[1, 2]]),
            (&[1, 2], 3, &[&[1, 2]]),
            (&[1, 2, 3], 1, &[&[1], &[2], &[3]]),
            (&[1, 2, 3], 2, &[&[1, 2], &[3]]),
            (&[1, 2, 3, 4], 2, &[&[1, 2], &[3, 4]]),
        ];

        for &(input, size, expected) in cases {
            assert_eq!(chunked(input, size), expected, "input {input:?}, size {size}");
        }
    }
}