// s3/request/request_trait.rs

1use hmac::Mac;
2use std::collections::HashMap;
3#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
4use std::pin::Pin;
5use time::format_description::well_known::Rfc2822;
6use time::OffsetDateTime;
7use url::Url;
8
9use crate::bucket::Bucket;
10use crate::command::Command;
11use crate::error::S3Error;
12use crate::signing;
13use crate::LONG_DATETIME;
14use bytes::Bytes;
15use http::header::{
16    HeaderName, ACCEPT, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, DATE, HOST, RANGE,
17};
18use http::HeaderMap;
19use std::fmt::Write as _;
20
21#[cfg(feature = "with-async-std")]
22use futures_util::Stream;
23
24#[cfg(feature = "with-tokio")]
25use tokio_stream::Stream;
26
27#[derive(Debug)]
28
29pub struct ResponseData {
30    bytes: Bytes,
31    status_code: u16,
32    headers: HashMap<String, String>,
33}
34
/// Streaming counterpart of [`ResponseData`]: the body is exposed as a stream
/// of [`Bytes`] chunks rather than a single buffer.
#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
pub struct ResponseDataStream {
    /// Body chunks; pinned and boxed so the concrete stream type stays hidden.
    pub bytes: Pin<Box<dyn Stream<Item = Bytes>>>,
    /// Status code reported for the response.
    pub status_code: u16,
}
40
#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
impl ResponseDataStream {
    /// Mutable access to the body stream, so callers can poll it chunk by chunk.
    pub fn bytes(&mut self) -> &mut Pin<Box<dyn Stream<Item = Bytes>>> {
        let Self { bytes, .. } = self;
        bytes
    }
}
47
48impl From<ResponseData> for Vec<u8> {
49    fn from(data: ResponseData) -> Vec<u8> {
50        data.to_vec()
51    }
52}
53
54impl ResponseData {
55    pub fn new(bytes: Bytes, status_code: u16, headers: HashMap<String, String>) -> ResponseData {
56        ResponseData {
57            bytes,
58            status_code,
59            headers,
60        }
61    }
62
63    pub fn as_slice(&self) -> &[u8] {
64        &self.bytes
65    }
66
67    pub fn to_vec(self) -> Vec<u8> {
68        self.bytes.to_vec()
69    }
70
71    pub fn bytes(&self) -> &Bytes {
72        &self.bytes
73    }
74
75    pub fn status_code(&self) -> u16 {
76        self.status_code
77    }
78
79    pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> {
80        std::str::from_utf8(self.as_slice())
81    }
82
83    pub fn to_string(&self) -> Result<String, std::str::Utf8Error> {
84        std::str::from_utf8(self.as_slice()).map(|s| s.to_string())
85    }
86
87    pub fn headers(&self) -> HashMap<String, String> {
88        self.headers.clone()
89    }
90}
91
92use std::fmt;
93
94impl fmt::Display for ResponseData {
95    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96        write!(
97            f,
98            "Status code: {}\n Data: {}",
99            self.status_code(),
100            self.to_string()
101                .unwrap_or_else(|_| "Data could not be cast to UTF string".to_string())
102        )
103    }
104}
105
106#[maybe_async::maybe_async]
107pub trait Request {
108    type Response;
109    type HeaderMap;
110
111    async fn response(&self) -> Result<Self::Response, S3Error>;
112    async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error>;
113    #[cfg(feature = "with-tokio")]
114    async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin>(
115        &self,
116        writer: &mut T,
117    ) -> Result<u16, S3Error>;
118    #[cfg(feature = "with-async-std")]
119    async fn response_data_to_writer<T: futures_io::AsyncWrite + Send + Unpin>(
120        &self,
121        writer: &mut T,
122    ) -> Result<u16, S3Error>;
123    #[cfg(feature = "sync")]
124    fn response_data_to_writer<T: std::io::Write + Send>(
125        &self,
126        writer: &mut T,
127    ) -> Result<u16, S3Error>;
128    #[cfg(any(feature = "with-async-std", feature = "with-tokio"))]
129    async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error>;
130    async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error>;
131    fn datetime(&self) -> OffsetDateTime;
132    fn bucket(&self) -> Bucket;
133    fn command(&self) -> Command;
134    fn path(&self) -> String;
135
136    fn signing_key(&self) -> Result<Vec<u8>, S3Error> {
137        signing::signing_key(
138            &self.datetime(),
139            &self
140                .bucket()
141                .secret_key()?
142                .expect("Secret key must be provided to sign headers, found None"),
143            &self.bucket().region(),
144            "s3",
145        )
146    }
147
148    fn request_body(&self) -> Vec<u8> {
149        if let Command::PutObject { content, .. } = self.command() {
150            Vec::from(content)
151        } else if let Command::PutObjectTagging { tags } = self.command() {
152            Vec::from(tags)
153        } else if let Command::UploadPart { content, .. } = self.command() {
154            Vec::from(content)
155        } else if let Command::CompleteMultipartUpload { data, .. } = &self.command() {
156            let body = data.to_string();
157            println!("CompleteMultipartUpload: {}", body);
158            body.as_bytes().to_vec()
159        } else if let Command::CreateBucket { config } = &self.command() {
160            if let Some(payload) = config.location_constraint_payload() {
161                Vec::from(payload)
162            } else {
163                Vec::new()
164            }
165        } else {
166            Vec::new()
167        }
168    }
169
170    fn long_date(&self) -> Result<String, S3Error> {
171        Ok(self.datetime().format(LONG_DATETIME)?)
172    }
173
174    fn string_to_sign(&self, request: &str) -> Result<String, S3Error> {
175        match self.command() {
176            Command::PresignPost { post_policy, .. } => Ok(post_policy),
177            _ => Ok(signing::string_to_sign(
178                &self.datetime(),
179                &self.bucket().region(),
180                request,
181            )?),
182        }
183    }
184
185    fn host_header(&self) -> String {
186        self.bucket().host()
187    }
188
189    fn presigned(&self) -> Result<String, S3Error> {
190        let (expiry, custom_headers, custom_queries) = match self.command() {
191            Command::PresignGet {
192                expiry_secs,
193                custom_queries,
194            } => (expiry_secs, None, custom_queries),
195            Command::PresignPut {
196                expiry_secs,
197                custom_headers,
198            } => (expiry_secs, custom_headers, None),
199            Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
200            _ => unreachable!(),
201        };
202
203        Ok(format!(
204            "{}&X-Amz-Signature={}",
205            self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?,
206            self.presigned_authorization(custom_headers.as_ref())?
207        ))
208    }
209
210    fn presigned_authorization(
211        &self,
212        custom_headers: Option<&HeaderMap>,
213    ) -> Result<String, S3Error> {
214        let mut headers = HeaderMap::new();
215        let host_header = self.host_header();
216        headers.insert(HOST, host_header.parse()?);
217        if let Some(custom_headers) = custom_headers {
218            for (k, v) in custom_headers.iter() {
219                headers.insert(k.clone(), v.clone());
220            }
221        }
222        let canonical_request = self.presigned_canonical_request(&headers)?;
223        let string_to_sign = self.string_to_sign(&canonical_request)?;
224        let mut hmac = signing::HmacSha256::new_from_slice(&self.signing_key()?)?;
225        hmac.update(string_to_sign.as_bytes());
226        let signature = hex::encode(hmac.finalize().into_bytes());
227        // let signed_header = signing::signed_header_string(&headers);
228        Ok(signature)
229    }
230
231    fn presigned_canonical_request(&self, headers: &HeaderMap) -> Result<String, S3Error> {
232        let (expiry, custom_headers, custom_queries) = match self.command() {
233            Command::PresignGet {
234                expiry_secs,
235                custom_queries,
236            } => (expiry_secs, None, custom_queries),
237            Command::PresignPut {
238                expiry_secs,
239                custom_headers,
240            } => (expiry_secs, custom_headers, None),
241            Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
242            _ => unreachable!(),
243        };
244
245        signing::canonical_request(
246            &self.command().http_verb().to_string(),
247            &self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?,
248            headers,
249            "UNSIGNED-PAYLOAD",
250        )
251    }
252
253    fn presigned_url_no_sig(
254        &self,
255        expiry: u32,
256        custom_headers: Option<&HeaderMap>,
257        custom_queries: Option<&HashMap<String, String>>,
258    ) -> Result<Url, S3Error> {
259        let bucket = self.bucket();
260        let token = if let Some(security_token) = bucket.security_token()? {
261            Some(security_token)
262        } else {
263            bucket.session_token()?
264        };
265        let url = Url::parse(&format!(
266            "{}{}{}",
267            self.url()?,
268            &signing::authorization_query_params_no_sig(
269                &self.bucket().access_key()?.unwrap_or_default(),
270                &self.datetime(),
271                &self.bucket().region(),
272                expiry,
273                custom_headers,
274                token.as_ref()
275            )?,
276            &signing::flatten_queries(custom_queries)?,
277        ))?;
278
279        Ok(url)
280    }
281
282    fn url(&self) -> Result<Url, S3Error> {
283        let mut url_str = self.bucket().url();
284
285        if let Command::CreateBucket { .. } = self.command() {
286            return Ok(Url::parse(&url_str)?);
287        }
288
289        let path = if self.path().starts_with('/') {
290            self.path()[1..].to_string()
291        } else {
292            self.path()[..].to_string()
293        };
294
295        url_str.push('/');
296        url_str.push_str(&signing::uri_encode(&path, false));
297
298        // Append to url_path
299        #[allow(clippy::collapsible_match)]
300        match self.command() {
301            Command::InitiateMultipartUpload { .. } | Command::ListMultipartUploads { .. } => {
302                url_str.push_str("?uploads")
303            }
304            Command::AbortMultipartUpload { upload_id } => {
305                write!(url_str, "?uploadId={}", upload_id).expect("Could not write to url_str");
306            }
307            Command::CompleteMultipartUpload { upload_id, .. } => {
308                write!(url_str, "?uploadId={}", upload_id).expect("Could not write to url_str");
309            }
310            Command::GetObjectTorrent => url_str.push_str("?torrent"),
311            Command::PutObject { multipart, .. } => {
312                if let Some(multipart) = multipart {
313                    url_str.push_str(&multipart.query_string())
314                }
315            }
316            _ => {}
317        }
318
319        let mut url = Url::parse(&url_str)?;
320
321        for (key, value) in &self.bucket().extra_query {
322            url.query_pairs_mut().append_pair(key, value);
323        }
324
325        if let Command::ListObjectsV2 {
326            prefix,
327            delimiter,
328            continuation_token,
329            start_after,
330            max_keys,
331        } = self.command().clone()
332        {
333            let mut query_pairs = url.query_pairs_mut();
334            delimiter.map(|d| query_pairs.append_pair("delimiter", &d));
335
336            query_pairs.append_pair("prefix", &prefix);
337            query_pairs.append_pair("list-type", "2");
338            if let Some(token) = continuation_token {
339                query_pairs.append_pair("continuation-token", &token);
340            }
341            if let Some(start_after) = start_after {
342                query_pairs.append_pair("start-after", &start_after);
343            }
344            if let Some(max_keys) = max_keys {
345                query_pairs.append_pair("max-keys", &max_keys.to_string());
346            }
347        }
348
349        if let Command::ListObjects {
350            prefix,
351            delimiter,
352            marker,
353            max_keys,
354        } = self.command().clone()
355        {
356            let mut query_pairs = url.query_pairs_mut();
357            delimiter.map(|d| query_pairs.append_pair("delimiter", &d));
358
359            query_pairs.append_pair("prefix", &prefix);
360            if let Some(marker) = marker {
361                query_pairs.append_pair("marker", &marker);
362            }
363            if let Some(max_keys) = max_keys {
364                query_pairs.append_pair("max-keys", &max_keys.to_string());
365            }
366        }
367
368        match self.command() {
369            Command::ListMultipartUploads {
370                prefix,
371                delimiter,
372                key_marker,
373                max_uploads,
374            } => {
375                let mut query_pairs = url.query_pairs_mut();
376                delimiter.map(|d| query_pairs.append_pair("delimiter", d));
377                if let Some(prefix) = prefix {
378                    query_pairs.append_pair("prefix", prefix);
379                }
380                if let Some(key_marker) = key_marker {
381                    query_pairs.append_pair("key-marker", &key_marker);
382                }
383                if let Some(max_uploads) = max_uploads {
384                    query_pairs.append_pair("max-uploads", max_uploads.to_string().as_str());
385                }
386            }
387            Command::PutObjectTagging { .. }
388            | Command::GetObjectTagging
389            | Command::DeleteObjectTagging => {
390                url.query_pairs_mut().append_pair("tagging", "");
391            }
392            _ => {}
393        }
394
395        Ok(url)
396    }
397
398    fn canonical_request(&self, headers: &HeaderMap) -> Result<String, S3Error> {
399        signing::canonical_request(
400            &self.command().http_verb().to_string(),
401            &self.url()?,
402            headers,
403            &self.command().sha256(),
404        )
405    }
406
407    fn authorization(&self, headers: &HeaderMap) -> Result<String, S3Error> {
408        let canonical_request = self.canonical_request(headers)?;
409        let string_to_sign = self.string_to_sign(&canonical_request)?;
410        let mut hmac = signing::HmacSha256::new_from_slice(&self.signing_key()?)?;
411        hmac.update(string_to_sign.as_bytes());
412        let signature = hex::encode(hmac.finalize().into_bytes());
413        let signed_header = signing::signed_header_string(headers);
414        signing::authorization_header(
415            &self.bucket().access_key()?.expect("No access_key provided"),
416            &self.datetime(),
417            &self.bucket().region(),
418            &signed_header,
419            &signature,
420        )
421    }
422
423    fn headers(&self) -> Result<HeaderMap, S3Error> {
424        // Generate this once, but it's used in more than one place.
425        let sha256 = self.command().sha256();
426
427        // Start with extra_headers, that way our headers replace anything with
428        // the same name.
429
430        let mut headers = HeaderMap::new();
431
432        for (k, v) in self.bucket().extra_headers.iter() {
433            headers.insert(k.clone(), v.clone());
434        }
435
436        let host_header = self.host_header();
437
438        headers.insert(HOST, host_header.parse()?);
439
440        match self.command() {
441            Command::CopyObject { from } => {
442                headers.insert(HeaderName::from_static("x-amz-copy-source"), from.parse()?);
443            }
444            Command::ListObjects { .. } => {}
445            Command::ListObjectsV2 { .. } => {}
446            Command::GetObject => {}
447            Command::GetObjectTagging => {}
448            Command::GetBucketLocation => {}
449            _ => {
450                headers.insert(
451                    CONTENT_LENGTH,
452                    self.command().content_length().to_string().parse()?,
453                );
454                headers.insert(CONTENT_TYPE, self.command().content_type().parse()?);
455            }
456        }
457        headers.insert(
458            HeaderName::from_static("x-amz-content-sha256"),
459            sha256.parse()?,
460        );
461        headers.insert(
462            HeaderName::from_static("x-amz-date"),
463            self.long_date()?.parse()?,
464        );
465
466        if let Some(session_token) = self.bucket().session_token()? {
467            headers.insert(
468                HeaderName::from_static("x-amz-security-token"),
469                session_token.parse()?,
470            );
471        } else if let Some(security_token) = self.bucket().security_token()? {
472            headers.insert(
473                HeaderName::from_static("x-amz-security-token"),
474                security_token.parse()?,
475            );
476        }
477
478        if let Command::PutObjectTagging { tags } = self.command() {
479            let digest = md5::compute(tags);
480            let hash = base64::encode(digest.as_ref());
481            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
482        } else if let Command::PutObject { content, .. } = self.command() {
483            let digest = md5::compute(content);
484            let hash = base64::encode(digest.as_ref());
485            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
486        } else if let Command::UploadPart { content, .. } = self.command() {
487            let digest = md5::compute(content);
488            let hash = base64::encode(digest.as_ref());
489            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
490        } else if let Command::GetObject {} = self.command() {
491            headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?);
492        // headers.insert(header::ACCEPT_CHARSET, HeaderValue::from_str("UTF-8")?);
493        } else if let Command::GetObjectRange { start, end } = self.command() {
494            headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?);
495
496            let mut range = format!("bytes={}-", start);
497
498            if let Some(end) = end {
499                range.push_str(&end.to_string());
500            }
501
502            headers.insert(RANGE, range.parse()?);
503        } else if let Command::CreateBucket { ref config } = self.command() {
504            config.add_headers(&mut headers)?;
505        }
506
507        // This must be last, as it signs the other headers, omitted if no secret key is provided
508        if self.bucket().secret_key()?.is_some() {
509            let authorization = self.authorization(&headers)?;
510            headers.insert(AUTHORIZATION, authorization.parse()?);
511        }
512
513        // The format of RFC2822 is somewhat malleable, so including it in
514        // signed headers can cause signature mismatches. We do include the
515        // X-Amz-Date header, so requests are still properly limited to a date
516        // range and can't be used again e.g. reply attacks. Adding this header
517        // after the generation of the Authorization header leaves it out of
518        // the signed headers.
519        headers.insert(DATE, self.datetime().format(&Rfc2822)?.parse()?);
520
521        Ok(headers)
522    }
523}