s3/request/tokio_backend.rs

1extern crate base64;
2extern crate md5;
3
4use bytes::Bytes;
5use maybe_async::maybe_async;
6use reqwest::{Client, Response};
7use std::collections::HashMap;
8use time::OffsetDateTime;
9
10use super::request_trait::{Request, ResponseData, ResponseDataStream};
11use crate::bucket::Bucket;
12use crate::command::Command;
13use crate::command::HttpMethod;
14use crate::error::S3Error;
15
16use tokio_stream::StreamExt;
17
18// Temporary structure for making a request
19pub struct Reqwest<'a> {
20    pub bucket: &'a Bucket,
21    pub path: &'a str,
22    pub command: Command<'a>,
23    pub datetime: OffsetDateTime,
24    pub sync: bool,
25}
26
27#[maybe_async]
28impl<'a> Request for Reqwest<'a> {
29    type Response = reqwest::Response;
30    type HeaderMap = reqwest::header::HeaderMap;
31
32    fn command(&self) -> Command {
33        self.command.clone()
34    }
35
36    fn path(&self) -> String {
37        self.path.to_string()
38    }
39
40    fn datetime(&self) -> OffsetDateTime {
41        self.datetime
42    }
43
44    fn bucket(&self) -> Bucket {
45        self.bucket.clone()
46    }
47
48    async fn response(&self) -> Result<Response, S3Error> {
49        // Build headers
50        let headers = match self.headers() {
51            Ok(headers) => headers,
52            Err(e) => return Err(e),
53        };
54
55        let mut client_builder = Client::builder();
56        if let Some(timeout) = self.bucket.request_timeout {
57            client_builder = client_builder.timeout(timeout)
58        }
59
60        if cfg!(feature = "no-verify-ssl") {
61            cfg_if::cfg_if! {
62                if #[cfg(feature = "tokio-native-tls")]
63                {
64                    client_builder = client_builder.danger_accept_invalid_hostnames(true);
65                }
66
67            }
68
69            cfg_if::cfg_if! {
70                if #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
71                {
72                    client_builder = client_builder.danger_accept_invalid_certs(true);
73                }
74
75            }
76        }
77
78        let client = client_builder.build()?;
79
80        let method = match self.command.http_verb() {
81            HttpMethod::Delete => reqwest::Method::DELETE,
82            HttpMethod::Get => reqwest::Method::GET,
83            HttpMethod::Post => reqwest::Method::POST,
84            HttpMethod::Put => reqwest::Method::PUT,
85            HttpMethod::Head => reqwest::Method::HEAD,
86        };
87
88        let request = client
89            .request(method, self.url()?.as_str())
90            .headers(headers)
91            .body(self.request_body());
92
93        let response = request.send().await?;
94
95        if cfg!(feature = "fail-on-err") && !response.status().is_success() {
96            let status = response.status().as_u16();
97            let text = response.text().await?;
98            return Err(S3Error::Http(status, text));
99        }
100
101        Ok(response)
102    }
103
104    async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error> {
105        let response = self.response().await?;
106        let status_code = response.status().as_u16();
107        let mut headers = response.headers().clone();
108        let response_headers = headers
109            .clone()
110            .iter()
111            .map(|(k, v)| {
112                (
113                    k.to_string(),
114                    v.to_str()
115                        .unwrap_or("could-not-decode-header-value")
116                        .to_string(),
117                )
118            })
119            .collect::<HashMap<String, String>>();
120        let body_vec = if etag {
121            if let Some(etag) = headers.remove("ETag") {
122                Bytes::from(etag.to_str()?.to_string())
123            } else {
124                Bytes::from("")
125            }
126        } else {
127            response.bytes().await?
128        };
129        Ok(ResponseData::new(body_vec, status_code, response_headers))
130    }
131
132    async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin>(
133        &self,
134        writer: &mut T,
135    ) -> Result<u16, S3Error> {
136        use tokio::io::AsyncWriteExt;
137        let response = self.response().await?;
138
139        let status_code = response.status();
140        let mut stream = response.bytes_stream();
141
142        while let Some(item) = stream.next().await {
143            writer.write_all(&item?).await?;
144        }
145
146        Ok(status_code.as_u16())
147    }
148
149    async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> {
150        let response = self.response().await?;
151        let status_code = response.status().as_u16();
152        let headers = response.headers().clone();
153        Ok((headers, status_code))
154    }
155
156    async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error> {
157        let response = self.response().await?;
158        let status_code = response.status();
159        let stream = response.bytes_stream().filter_map(|b| b.ok());
160
161        Ok(ResponseDataStream {
162            bytes: Box::pin(stream),
163            status_code: status_code.as_u16(),
164        })
165    }
166}
167
168impl<'a> Reqwest<'a> {
169    pub fn new<'b>(
170        bucket: &'b Bucket,
171        path: &'b str,
172        command: Command<'b>,
173    ) -> Result<Reqwest<'b>, S3Error> {
174        bucket.credentials_refresh()?;
175        Ok(Reqwest {
176            bucket,
177            path,
178            command,
179            datetime: OffsetDateTime::now_utc(),
180            sync: false,
181        })
182    }
183}
184
#[cfg(test)]
mod tests {
    use crate::bucket::Bucket;
    use crate::command::Command;
    use crate::request::tokio_backend::Reqwest;
    use crate::request::Request;
    use awscreds::Credentials;
    use http::header::{HOST, RANGE};

    // Fake keys - otherwise using Credentials::default will use actual user
    // credentials if they exist.
    fn fake_credentials() -> Credentials {
        let access_key = "AKIAIOSFODNN7EXAMPLE";
        let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
        Credentials::new(Some(access_key), Some(secret_key), None, None, None).unwrap()
    }

    /// Asserts that a `GetObject` request for `path` on `bucket` uses URL
    /// scheme `scheme` and sends `host` as its `Host` header.
    fn assert_scheme_and_host(bucket: &Bucket, path: &str, scheme: &str, host: &str) {
        let request = Reqwest::new(bucket, path, Command::GetObject).unwrap();
        assert_eq!(request.url().unwrap().scheme(), scheme);

        let headers = request.headers().unwrap();
        assert_eq!(headers.get(HOST).unwrap(), host);
    }

    #[test]
    fn url_uses_https_by_default() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();
        assert_scheme_and_host(
            &bucket,
            "/my-first/path",
            "https",
            "my-first-bucket.custom-region",
        );
    }

    #[test]
    fn url_uses_https_by_default_path_style() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        assert_scheme_and_host(&bucket, "/my-first/path", "https", "custom-region");
    }

    #[test]
    fn url_uses_scheme_from_custom_region_if_defined() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap();
        assert_scheme_and_host(
            &bucket,
            "/my-second/path",
            "http",
            "my-second-bucket.custom-region",
        );
    }

    #[test]
    fn url_uses_scheme_from_custom_region_if_defined_with_path_style() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        assert_scheme_and_host(&bucket, "/my-second/path", "http", "custom-region");
    }

    #[test]
    fn test_get_object_range_header() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let path = "/my-second/path";

        // An open-ended range omits the upper bound; a closed one includes it.
        for (end, expected) in [(None, "bytes=0-"), (Some(1), "bytes=0-1")] {
            let request =
                Reqwest::new(&bucket, path, Command::GetObjectRange { start: 0, end }).unwrap();
            let headers = request.headers().unwrap();
            assert_eq!(headers.get(RANGE).unwrap(), expected);
        }
    }
}