1extern crate base64;
2extern crate md5;
3
4use bytes::Bytes;
5use maybe_async::maybe_async;
6use reqwest::{Client, Response};
7use std::collections::HashMap;
8use time::OffsetDateTime;
9
10use super::request_trait::{Request, ResponseData, ResponseDataStream};
11use crate::bucket::Bucket;
12use crate::command::Command;
13use crate::command::HttpMethod;
14use crate::error::S3Error;
15
16use tokio_stream::StreamExt;
17
/// Request backend built on `reqwest` (its `Request` impl in this file uses
/// async fns and tokio I/O traits). One value describes a single S3 call.
pub struct Reqwest<'a> {
    /// Bucket the request is issued against (borrowed, not owned).
    pub bucket: &'a Bucket,
    /// Request path, e.g. `/my-first/path` in this file's tests.
    pub path: &'a str,
    /// The S3 operation to perform; also determines the HTTP verb.
    pub command: Command<'a>,
    /// Timestamp of the request; `Reqwest::new` fills it with
    /// `OffsetDateTime::now_utc()`.
    pub datetime: OffsetDateTime,
    /// Always initialised to `false` by `Reqwest::new` and not read in this file.
    /// NOTE(review): semantics not visible here — confirm against the other backends.
    pub sync: bool,
}
26
27#[maybe_async]
28impl<'a> Request for Reqwest<'a> {
29 type Response = reqwest::Response;
30 type HeaderMap = reqwest::header::HeaderMap;
31
32 fn command(&self) -> Command {
33 self.command.clone()
34 }
35
36 fn path(&self) -> String {
37 self.path.to_string()
38 }
39
40 fn datetime(&self) -> OffsetDateTime {
41 self.datetime
42 }
43
44 fn bucket(&self) -> Bucket {
45 self.bucket.clone()
46 }
47
48 async fn response(&self) -> Result<Response, S3Error> {
49 let headers = match self.headers() {
51 Ok(headers) => headers,
52 Err(e) => return Err(e),
53 };
54
55 let mut client_builder = Client::builder();
56 if let Some(timeout) = self.bucket.request_timeout {
57 client_builder = client_builder.timeout(timeout)
58 }
59
60 if cfg!(feature = "no-verify-ssl") {
61 cfg_if::cfg_if! {
62 if #[cfg(feature = "tokio-native-tls")]
63 {
64 client_builder = client_builder.danger_accept_invalid_hostnames(true);
65 }
66
67 }
68
69 cfg_if::cfg_if! {
70 if #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
71 {
72 client_builder = client_builder.danger_accept_invalid_certs(true);
73 }
74
75 }
76 }
77
78 let client = client_builder.build()?;
79
80 let method = match self.command.http_verb() {
81 HttpMethod::Delete => reqwest::Method::DELETE,
82 HttpMethod::Get => reqwest::Method::GET,
83 HttpMethod::Post => reqwest::Method::POST,
84 HttpMethod::Put => reqwest::Method::PUT,
85 HttpMethod::Head => reqwest::Method::HEAD,
86 };
87
88 let request = client
89 .request(method, self.url()?.as_str())
90 .headers(headers)
91 .body(self.request_body());
92
93 let response = request.send().await?;
94
95 if cfg!(feature = "fail-on-err") && !response.status().is_success() {
96 let status = response.status().as_u16();
97 let text = response.text().await?;
98 return Err(S3Error::Http(status, text));
99 }
100
101 Ok(response)
102 }
103
104 async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error> {
105 let response = self.response().await?;
106 let status_code = response.status().as_u16();
107 let mut headers = response.headers().clone();
108 let response_headers = headers
109 .clone()
110 .iter()
111 .map(|(k, v)| {
112 (
113 k.to_string(),
114 v.to_str()
115 .unwrap_or("could-not-decode-header-value")
116 .to_string(),
117 )
118 })
119 .collect::<HashMap<String, String>>();
120 let body_vec = if etag {
121 if let Some(etag) = headers.remove("ETag") {
122 Bytes::from(etag.to_str()?.to_string())
123 } else {
124 Bytes::from("")
125 }
126 } else {
127 response.bytes().await?
128 };
129 Ok(ResponseData::new(body_vec, status_code, response_headers))
130 }
131
132 async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin>(
133 &self,
134 writer: &mut T,
135 ) -> Result<u16, S3Error> {
136 use tokio::io::AsyncWriteExt;
137 let response = self.response().await?;
138
139 let status_code = response.status();
140 let mut stream = response.bytes_stream();
141
142 while let Some(item) = stream.next().await {
143 writer.write_all(&item?).await?;
144 }
145
146 Ok(status_code.as_u16())
147 }
148
149 async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> {
150 let response = self.response().await?;
151 let status_code = response.status().as_u16();
152 let headers = response.headers().clone();
153 Ok((headers, status_code))
154 }
155
156 async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error> {
157 let response = self.response().await?;
158 let status_code = response.status();
159 let stream = response.bytes_stream().filter_map(|b| b.ok());
160
161 Ok(ResponseDataStream {
162 bytes: Box::pin(stream),
163 status_code: status_code.as_u16(),
164 })
165 }
166}
167
168impl<'a> Reqwest<'a> {
169 pub fn new<'b>(
170 bucket: &'b Bucket,
171 path: &'b str,
172 command: Command<'b>,
173 ) -> Result<Reqwest<'b>, S3Error> {
174 bucket.credentials_refresh()?;
175 Ok(Reqwest {
176 bucket,
177 path,
178 command,
179 datetime: OffsetDateTime::now_utc(),
180 sync: false,
181 })
182 }
183}
184
#[cfg(test)]
mod tests {
    use crate::bucket::Bucket;
    use crate::command::Command;
    use crate::request::tokio_backend::Reqwest;
    use crate::request::Request;
    use awscreds::Credentials;
    use http::header::{HOST, RANGE};

    /// Dummy credentials. The tests only inspect the locally built URL and
    /// headers, so nothing is ever sent over the network.
    fn fake_credentials() -> Credentials {
        let access_key = "AKIAIOSFODNN7EXAMPLE";
        let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
        Credentials::new(Some(access_key), Some(secret_key), None, None, None).unwrap()
    }

    /// Asserts that `request` would send the given `Host` header value.
    fn assert_host(request: &Reqwest<'_>, expected: &str) {
        let headers = request.headers().unwrap();
        assert_eq!(headers.get(HOST).unwrap(), expected);
    }

    #[test]
    fn url_uses_https_by_default() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();
        let request = Reqwest::new(&bucket, "/my-first/path", Command::GetObject).unwrap();

        let url = request.url().unwrap();
        assert_eq!(url.scheme(), "https");
        assert_host(&request, "my-first-bucket.custom-region");
    }

    #[test]
    fn url_uses_https_by_default_path_style() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let request = Reqwest::new(&bucket, "/my-first/path", Command::GetObject).unwrap();

        let url = request.url().unwrap();
        assert_eq!(url.scheme(), "https");
        assert_host(&request, "custom-region");
    }

    #[test]
    fn url_uses_scheme_from_custom_region_if_defined() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap();
        let request = Reqwest::new(&bucket, "/my-second/path", Command::GetObject).unwrap();

        let url = request.url().unwrap();
        assert_eq!(url.scheme(), "http");
        assert_host(&request, "my-second-bucket.custom-region");
    }

    #[test]
    fn url_uses_scheme_from_custom_region_if_defined_with_path_style() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let request = Reqwest::new(&bucket, "/my-second/path", Command::GetObject).unwrap();

        let url = request.url().unwrap();
        assert_eq!(url.scheme(), "http");
        assert_host(&request, "custom-region");
    }

    #[test]
    fn test_get_object_range_header() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let path = "/my-second/path";

        // (end of range, expected `Range` header value)
        let cases = [(None, "bytes=0-"), (Some(1), "bytes=0-1")];
        for (end, expected) in cases {
            let command = Command::GetObjectRange { start: 0, end };
            let request = Reqwest::new(&bucket, path, command).unwrap();
            let headers = request.headers().unwrap();
            assert_eq!(headers.get(RANGE).unwrap(), expected);
        }
    }
}