1extern crate base64;
2extern crate md5;
3
4use bytes::Bytes;
5use futures_util::TryStreamExt;
6use maybe_async::maybe_async;
7use std::collections::HashMap;
8use std::str::FromStr as _;
9use time::OffsetDateTime;
10
11use super::request_trait::{Request, ResponseData, ResponseDataStream};
12use crate::bucket::Bucket;
13use crate::command::Command;
14use crate::command::HttpMethod;
15use crate::error::S3Error;
16use crate::retry;
17use crate::utils::now_utc;
18
19use tokio_stream::StreamExt;
20
/// Settings consumed by `client` when building the `reqwest::Client`.
///
/// The derived `Default` gives no timeout, no proxy and, when a TLS backend feature is enabled,
/// both `accept_invalid_*` flags set to `false`.
#[derive(Clone, Debug, Default)]
pub(crate) struct ClientOptions {
    /// Passed to `reqwest::ClientBuilder::timeout` when set; `None` leaves the builder untouched.
    pub request_timeout: Option<std::time::Duration>,
    /// Registered with `reqwest::ClientBuilder::proxy` (cloned) when set.
    pub proxy: Option<reqwest::Proxy>,
    /// Forwarded to `danger_accept_invalid_certs`; the field only exists with a TLS backend feature.
    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
    pub accept_invalid_certs: bool,
    /// Forwarded to `danger_accept_invalid_hostnames`; the field only exists with a TLS backend feature.
    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
    pub accept_invalid_hostnames: bool,
}
30
/// Builds a `reqwest::Client` configured from `options`.
///
/// Applies the optional request timeout and proxy. When a TLS backend feature
/// (`tokio-native-tls` or `tokio-rustls-tls`) is compiled in, it also applies the
/// certificate and hostname verification overrides.
///
/// # Errors
/// Returns an error if `reqwest` fails to build the client.
#[cfg(feature = "with-tokio")]
pub(crate) fn client(options: &ClientOptions) -> Result<reqwest::Client, S3Error> {
    let mut builder = reqwest::Client::builder();

    if let Some(timeout) = options.request_timeout {
        builder = builder.timeout(timeout);
    }

    if let Some(proxy) = &options.proxy {
        builder = builder.proxy(proxy.clone());
    }

    // The verification overrides only exist on `ClientOptions` when a TLS backend is compiled in.
    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
    {
        builder = builder
            .danger_accept_invalid_certs(options.accept_invalid_certs)
            .danger_accept_invalid_hostnames(options.accept_invalid_hostnames);
    }

    builder.build().map_err(S3Error::from)
}
/// A single S3 request executed with `reqwest`.
pub struct ReqwestRequest<'a> {
    /// Target bucket; also supplies the HTTP client via `http_client()`.
    pub bucket: &'a Bucket,
    /// Request path (e.g. `"/my-first/path"` in the tests), returned by `path()`.
    pub path: &'a str,
    /// S3 operation to perform; its `http_verb()` selects the HTTP method.
    pub command: Command<'a>,
    /// Timestamp captured by `new` via `now_utc()` and exposed through `datetime()`.
    /// NOTE(review): presumably the signing time used by `headers()` — confirm in the `Request` trait.
    pub datetime: OffsetDateTime,
    /// Set to `false` by `new`. NOTE(review): not read anywhere in this file — confirm its purpose.
    pub sync: bool,
}
69
70#[maybe_async]
71impl<'a> Request for ReqwestRequest<'a> {
72 type Response = reqwest::Response;
73 type HeaderMap = reqwest::header::HeaderMap;
74
75 async fn response(&self) -> Result<Self::Response, S3Error> {
76 let headers = self
77 .headers()
78 .await?
79 .iter()
80 .map(|(k, v)| {
81 (
82 reqwest::header::HeaderName::from_str(k.as_str()),
83 reqwest::header::HeaderValue::from_str(v.to_str().unwrap_or_default()),
84 )
85 })
86 .filter(|(k, v)| k.is_ok() && v.is_ok())
87 .map(|(k, v)| (k.unwrap(), v.unwrap()))
88 .collect();
89
90 let client = self.bucket.http_client();
91
92 let method = match self.command.http_verb() {
93 HttpMethod::Delete => reqwest::Method::DELETE,
94 HttpMethod::Get => reqwest::Method::GET,
95 HttpMethod::Post => reqwest::Method::POST,
96 HttpMethod::Put => reqwest::Method::PUT,
97 HttpMethod::Head => reqwest::Method::HEAD,
98 };
99
100 let request = client
101 .request(method, self.url()?.as_str())
102 .headers(headers)
103 .body(self.request_body()?);
104
105 let request = request.build()?;
106
107 let response = client.execute(request).await?;
110
111 if cfg!(feature = "fail-on-err") && !response.status().is_success() {
112 let status = response.status().as_u16();
113 let text = response.text().await?;
114 return Err(S3Error::HttpFailWithBody(status, text));
115 }
116
117 Ok(response)
118 }
119
120 async fn response_status(&self) -> Result<u16, S3Error> {
121 retry! {
122 async {
123 let headers = self
124 .headers()
125 .await?
126 .iter()
127 .map(|(k, v)| {
128 (
129 reqwest::header::HeaderName::from_str(k.as_str()),
130 reqwest::header::HeaderValue::from_str(v.to_str().unwrap_or_default()),
131 )
132 })
133 .filter(|(k, v)| k.is_ok() && v.is_ok())
134 .map(|(k, v)| (k.unwrap(), v.unwrap()))
135 .collect();
136
137 let client = self.bucket.http_client();
138
139 let method = match self.command.http_verb() {
140 HttpMethod::Delete => reqwest::Method::DELETE,
141 HttpMethod::Get => reqwest::Method::GET,
142 HttpMethod::Post => reqwest::Method::POST,
143 HttpMethod::Put => reqwest::Method::PUT,
144 HttpMethod::Head => reqwest::Method::HEAD,
145 };
146
147 let request = client
148 .request(method, self.url()?.as_str())
149 .headers(headers)
150 .body(self.request_body()?);
151
152 let request = request.build()?;
153 let response = client.execute(request).await?;
154 let status = response.status().as_u16();
155
156 if status == 404 {
157 return Ok(status);
158 }
159
160 if cfg!(feature = "fail-on-err") && !response.status().is_success() {
161 let text = response.text().await?;
162 return Err(S3Error::HttpFailWithBody(status, text));
163 }
164
165 Ok(status)
166 }.await
167 }
168 }
169
170 async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error> {
171 let response = retry! {self.response().await }?;
172 let status_code = response.status().as_u16();
173 let mut headers = response.headers().clone();
174 let response_headers = headers
175 .clone()
176 .iter()
177 .map(|(k, v)| {
178 (
179 k.to_string(),
180 v.to_str()
181 .unwrap_or("could-not-decode-header-value")
182 .to_string(),
183 )
184 })
185 .collect::<HashMap<String, String>>();
186 let body_vec = if etag {
199 if let Some(etag) = headers.remove("ETag") {
200 Bytes::from(etag.to_str()?.to_string())
201 } else {
202 Bytes::from("")
203 }
204 } else {
205 response.bytes().await?
206 };
207 Ok(ResponseData::new(body_vec, status_code, response_headers))
208 }
209
210 async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin + ?Sized>(
211 &self,
212 writer: &mut T,
213 ) -> Result<u16, S3Error> {
214 use tokio::io::AsyncWriteExt;
215 let response = retry! {self.response().await}?;
216
217 let status_code = response.status();
218 let mut stream = response.bytes_stream();
219
220 while let Some(item) = stream.next().await {
221 writer.write_all(&item?).await?;
222 }
223
224 Ok(status_code.as_u16())
225 }
226
227 async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error> {
228 let response = retry! {self.response().await}?;
229 let status_code = response.status();
230 let stream = response.bytes_stream().map_err(S3Error::Reqwest);
231
232 Ok(ResponseDataStream {
233 bytes: Box::pin(stream),
234 status_code: status_code.as_u16(),
235 })
236 }
237
238 async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> {
239 let response = retry! {self.response().await}?;
240 let status_code = response.status().as_u16();
241 let headers = response.headers().clone();
242 Ok((headers, status_code))
243 }
244
245 fn datetime(&self) -> OffsetDateTime {
246 self.datetime
247 }
248
249 fn bucket(&self) -> Bucket {
250 self.bucket.clone()
251 }
252
253 fn command(&self) -> Command<'_> {
254 self.command.clone()
255 }
256
257 fn path(&self) -> String {
258 self.path.to_string()
259 }
260}
261
262impl<'a> ReqwestRequest<'a> {
263 pub async fn new(
264 bucket: &'a Bucket,
265 path: &'a str,
266 command: Command<'a>,
267 ) -> Result<ReqwestRequest<'a>, S3Error> {
268 bucket.credentials_refresh().await?;
269 Ok(Self {
270 bucket,
271 path,
272 command,
273 datetime: now_utc(),
274 sync: false,
275 })
276 }
277}
278
#[cfg(test)]
mod tests {
    use crate::bucket::Bucket;
    use crate::command::Command;
    use crate::request::Request;
    use crate::request::tokio_backend::ReqwestRequest;
    use awscreds::Credentials;
    use http::header::{HOST, RANGE};

    /// Placeholder key pair used to construct test buckets.
    fn fake_credentials() -> Credentials {
        let access_key = "AKIAIOSFODNN7EXAMPLE";
        let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
        Credentials::new(Some(access_key), Some(secret_key), None, None, None).unwrap()
    }

    /// Builds a `GetObject` request for `path`, then checks its URL scheme and `Host` header.
    async fn assert_scheme_and_host(bucket: &Bucket, path: &str, scheme: &str, host: &str) {
        let request = ReqwestRequest::new(bucket, path, Command::GetObject)
            .await
            .unwrap();
        assert_eq!(request.url().unwrap().scheme(), scheme);

        let headers = request.headers().await.unwrap();
        assert_eq!(headers.get(HOST).unwrap(), host);
    }

    #[tokio::test]
    async fn url_uses_https_by_default() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();

        assert_scheme_and_host(
            &bucket,
            "/my-first/path",
            "https",
            "my-first-bucket.custom-region",
        )
        .await;
    }

    #[tokio::test]
    async fn url_uses_https_by_default_path_style() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();

        assert_scheme_and_host(&bucket, "/my-first/path", "https", "custom-region").await;
    }

    #[tokio::test]
    async fn url_uses_scheme_from_custom_region_if_defined() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap();

        assert_scheme_and_host(
            &bucket,
            "/my-second/path",
            "http",
            "my-second-bucket.custom-region",
        )
        .await;
    }

    #[tokio::test]
    async fn url_uses_scheme_from_custom_region_if_defined_with_path_style() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();

        assert_scheme_and_host(&bucket, "/my-second/path", "http", "custom-region").await;
    }

    #[tokio::test]
    async fn test_get_object_range_header() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let path = "/my-second/path";

        // An open-ended range and a closed range must both render as a `bytes=` header.
        for (end, expected) in [(None, "bytes=0-"), (Some(1), "bytes=0-1")] {
            let request = ReqwestRequest::new(
                &bucket,
                path,
                Command::GetObjectRange { start: 0, end },
            )
            .await
            .unwrap();
            let headers = request.headers().await.unwrap();
            assert_eq!(headers.get(RANGE).unwrap(), expected);
        }
    }
}