Skip to main content

s3/request/
tokio_backend.rs

1extern crate base64;
2extern crate md5;
3
4use bytes::Bytes;
5use futures_util::TryStreamExt;
6use maybe_async::maybe_async;
7use std::collections::HashMap;
8use std::str::FromStr as _;
9use time::OffsetDateTime;
10
11use super::request_trait::{Request, ResponseData, ResponseDataStream};
12use crate::bucket::Bucket;
13use crate::command::Command;
14use crate::command::HttpMethod;
15use crate::error::S3Error;
16use crate::retry;
17use crate::utils::now_utc;
18
19use tokio_stream::StreamExt;
20
21#[derive(Clone, Debug, Default)]
22pub(crate) struct ClientOptions {
23    pub request_timeout: Option<std::time::Duration>,
24    pub proxy: Option<reqwest::Proxy>,
25    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
26    pub accept_invalid_certs: bool,
27    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
28    pub accept_invalid_hostnames: bool,
29}
30
/// Builds the `reqwest::Client` used by the tokio backend from `options`.
///
/// The timeout and proxy are only applied when they are set. When one of the
/// tokio TLS features is enabled, the certificate and hostname verification
/// switches are forwarded to the builder exactly as given.
///
/// # Errors
///
/// Returns an error if `reqwest` fails to build the client.
#[cfg(feature = "with-tokio")]
pub(crate) fn client(options: &ClientOptions) -> Result<reqwest::Client, S3Error> {
    let mut builder = reqwest::Client::builder();

    if let Some(timeout) = options.request_timeout {
        builder = builder.timeout(timeout);
    }

    if let Some(proxy) = &options.proxy {
        builder = builder.proxy(proxy.clone());
    }

    // Both TLS switches share the same feature gate, so configure them together.
    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
    {
        builder = builder
            .danger_accept_invalid_certs(options.accept_invalid_certs)
            .danger_accept_invalid_hostnames(options.accept_invalid_hostnames);
    }

    Ok(builder.build()?)
}
61// Temporary structure for making a request
62pub struct ReqwestRequest<'a> {
63    pub bucket: &'a Bucket,
64    pub path: &'a str,
65    pub command: Command<'a>,
66    pub datetime: OffsetDateTime,
67    pub sync: bool,
68}
69
70#[maybe_async]
71impl<'a> Request for ReqwestRequest<'a> {
72    type Response = reqwest::Response;
73    type HeaderMap = reqwest::header::HeaderMap;
74
75    async fn response(&self) -> Result<Self::Response, S3Error> {
76        let headers = self
77            .headers()
78            .await?
79            .iter()
80            .map(|(k, v)| {
81                (
82                    reqwest::header::HeaderName::from_str(k.as_str()),
83                    reqwest::header::HeaderValue::from_str(v.to_str().unwrap_or_default()),
84                )
85            })
86            .filter(|(k, v)| k.is_ok() && v.is_ok())
87            .map(|(k, v)| (k.unwrap(), v.unwrap()))
88            .collect();
89
90        let client = self.bucket.http_client();
91
92        let method = match self.command.http_verb() {
93            HttpMethod::Delete => reqwest::Method::DELETE,
94            HttpMethod::Get => reqwest::Method::GET,
95            HttpMethod::Post => reqwest::Method::POST,
96            HttpMethod::Put => reqwest::Method::PUT,
97            HttpMethod::Head => reqwest::Method::HEAD,
98        };
99
100        let request = client
101            .request(method, self.url()?.as_str())
102            .headers(headers)
103            .body(self.request_body()?);
104
105        let request = request.build()?;
106
107        // println!("Request: {:?}", request);
108
109        let response = client.execute(request).await?;
110
111        if cfg!(feature = "fail-on-err") && !response.status().is_success() {
112            let status = response.status().as_u16();
113            let text = response.text().await?;
114            return Err(S3Error::HttpFailWithBody(status, text));
115        }
116
117        Ok(response)
118    }
119
120    async fn response_status(&self) -> Result<u16, S3Error> {
121        retry! {
122            async {
123                let headers = self
124                    .headers()
125                    .await?
126                    .iter()
127                    .map(|(k, v)| {
128                        (
129                            reqwest::header::HeaderName::from_str(k.as_str()),
130                            reqwest::header::HeaderValue::from_str(v.to_str().unwrap_or_default()),
131                        )
132                    })
133                    .filter(|(k, v)| k.is_ok() && v.is_ok())
134                    .map(|(k, v)| (k.unwrap(), v.unwrap()))
135                    .collect();
136
137                let client = self.bucket.http_client();
138
139                let method = match self.command.http_verb() {
140                    HttpMethod::Delete => reqwest::Method::DELETE,
141                    HttpMethod::Get => reqwest::Method::GET,
142                    HttpMethod::Post => reqwest::Method::POST,
143                    HttpMethod::Put => reqwest::Method::PUT,
144                    HttpMethod::Head => reqwest::Method::HEAD,
145                };
146
147                let request = client
148                    .request(method, self.url()?.as_str())
149                    .headers(headers)
150                    .body(self.request_body()?);
151
152                let request = request.build()?;
153                let response = client.execute(request).await?;
154                let status = response.status().as_u16();
155
156                if status == 404 {
157                    return Ok(status);
158                }
159
160                if cfg!(feature = "fail-on-err") && !response.status().is_success() {
161                    let text = response.text().await?;
162                    return Err(S3Error::HttpFailWithBody(status, text));
163                }
164
165                Ok(status)
166            }.await
167        }
168    }
169
170    async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error> {
171        let response = retry! {self.response().await }?;
172        let status_code = response.status().as_u16();
173        let mut headers = response.headers().clone();
174        let response_headers = headers
175            .clone()
176            .iter()
177            .map(|(k, v)| {
178                (
179                    k.to_string(),
180                    v.to_str()
181                        .unwrap_or("could-not-decode-header-value")
182                        .to_string(),
183                )
184            })
185            .collect::<HashMap<String, String>>();
186        // When etag=true, we extract the ETag header and return it as the body.
187        // This is used for PUT operations (regular puts, multipart chunks) where:
188        // 1. S3 returns an empty or non-useful response body
189        // 2. The ETag header contains the essential information we need
190        // 3. The calling code expects to get the ETag via response_data.as_str()
191        //
192        // Note: This approach means we discard any actual response body when etag=true,
193        // but for the operations that use this (PUTs), the body is typically empty
194        // or contains redundant information already available in headers.
195        //
196        // TODO: Refactor this to properly return the response body and access ETag
197        // from headers instead of replacing the body. This would be a breaking change.
198        let body_vec = if etag {
199            if let Some(etag) = headers.remove("ETag") {
200                Bytes::from(etag.to_str()?.to_string())
201            } else {
202                Bytes::from("")
203            }
204        } else {
205            response.bytes().await?
206        };
207        Ok(ResponseData::new(body_vec, status_code, response_headers))
208    }
209
210    async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin + ?Sized>(
211        &self,
212        writer: &mut T,
213    ) -> Result<u16, S3Error> {
214        use tokio::io::AsyncWriteExt;
215        let response = retry! {self.response().await}?;
216
217        let status_code = response.status();
218        let mut stream = response.bytes_stream();
219
220        while let Some(item) = stream.next().await {
221            writer.write_all(&item?).await?;
222        }
223
224        Ok(status_code.as_u16())
225    }
226
227    async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error> {
228        let response = retry! {self.response().await}?;
229        let status_code = response.status();
230        let stream = response.bytes_stream().map_err(S3Error::Reqwest);
231
232        Ok(ResponseDataStream {
233            bytes: Box::pin(stream),
234            status_code: status_code.as_u16(),
235        })
236    }
237
238    async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> {
239        let response = retry! {self.response().await}?;
240        let status_code = response.status().as_u16();
241        let headers = response.headers().clone();
242        Ok((headers, status_code))
243    }
244
245    fn datetime(&self) -> OffsetDateTime {
246        self.datetime
247    }
248
249    fn bucket(&self) -> Bucket {
250        self.bucket.clone()
251    }
252
253    fn command(&self) -> Command<'_> {
254        self.command.clone()
255    }
256
257    fn path(&self) -> String {
258        self.path.to_string()
259    }
260}
261
262impl<'a> ReqwestRequest<'a> {
263    pub async fn new(
264        bucket: &'a Bucket,
265        path: &'a str,
266        command: Command<'a>,
267    ) -> Result<ReqwestRequest<'a>, S3Error> {
268        bucket.credentials_refresh().await?;
269        Ok(Self {
270            bucket,
271            path,
272            command,
273            datetime: now_utc(),
274            sync: false,
275        })
276    }
277}
278
#[cfg(test)]
mod tests {
    use crate::bucket::Bucket;
    use crate::command::Command;
    use crate::request::Request;
    use crate::request::tokio_backend::ReqwestRequest;
    use awscreds::Credentials;
    use http::header::{HOST, RANGE};

    /// Fake keys - otherwise using `Credentials::default` would pick up the
    /// real user credentials if they exist.
    fn fake_credentials() -> Credentials {
        let access_key = "AKIAIOSFODNN7EXAMPLE";
        let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
        Credentials::new(Some(access_key), Some(secret_key), None, None, None).unwrap()
    }

    /// A region without a scheme yields an https URL and a virtual-host style Host header.
    #[tokio::test]
    async fn url_uses_https_by_default() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();
        let path = "/my-first/path";
        let request = ReqwestRequest::new(&bucket, path, Command::GetObject)
            .await
            .unwrap();

        assert_eq!(request.url().unwrap().scheme(), "https");

        let headers = request.headers().await.unwrap();
        let host = headers.get(HOST).unwrap();

        assert_eq!(host, "my-first-bucket.custom-region");
    }

    /// With path style, the Host header is the bare region endpoint.
    #[tokio::test]
    async fn url_uses_https_by_default_path_style() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let path = "/my-first/path";
        let request = ReqwestRequest::new(&bucket, path, Command::GetObject)
            .await
            .unwrap();

        assert_eq!(request.url().unwrap().scheme(), "https");

        let headers = request.headers().await.unwrap();
        let host = headers.get(HOST).unwrap();

        assert_eq!(host, "custom-region");
    }

    /// An explicit `http://` scheme on a custom region is honoured.
    #[tokio::test]
    async fn url_uses_scheme_from_custom_region_if_defined() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap();
        let path = "/my-second/path";
        let request = ReqwestRequest::new(&bucket, path, Command::GetObject)
            .await
            .unwrap();

        assert_eq!(request.url().unwrap().scheme(), "http");

        let headers = request.headers().await.unwrap();
        let host = headers.get(HOST).unwrap();
        assert_eq!(host, "my-second-bucket.custom-region");
    }

    /// An explicit scheme is also honoured with path-style addressing.
    #[tokio::test]
    async fn url_uses_scheme_from_custom_region_if_defined_with_path_style() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let path = "/my-second/path";
        let request = ReqwestRequest::new(&bucket, path, Command::GetObject)
            .await
            .unwrap();

        assert_eq!(request.url().unwrap().scheme(), "http");

        let headers = request.headers().await.unwrap();
        let host = headers.get(HOST).unwrap();
        assert_eq!(host, "custom-region");
    }

    /// `GetObjectRange` produces open-ended and closed `Range` headers.
    #[tokio::test]
    async fn test_get_object_range_header() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let path = "/my-second/path";

        let request = ReqwestRequest::new(
            &bucket,
            path,
            Command::GetObjectRange {
                start: 0,
                end: None,
            },
        )
        .await
        .unwrap();
        let headers = request.headers().await.unwrap();
        let range = headers.get(RANGE).unwrap();
        assert_eq!(range, "bytes=0-");

        let request = ReqwestRequest::new(
            &bucket,
            path,
            Command::GetObjectRange {
                start: 0,
                end: Some(1),
            },
        )
        .await
        .unwrap();
        let headers = request.headers().await.unwrap();
        let range = headers.get(RANGE).unwrap();
        assert_eq!(range, "bytes=0-1");
    }
}