s3handler/tokio_async/primitives/
s3.rs

1use async_trait::async_trait;
2use base64::encode;
3use bytes::{Bytes, BytesMut};
4use chrono::prelude::*;
5use dyn_clone::DynClone;
6use futures::future::join_all;
7use hmac::{Hmac, Mac};
8use reqwest::{
9    header::{self, HeaderMap, HeaderName, HeaderValue},
10    Client, Method, Request, Response, Url,
11};
12use sha2::Digest;
13use sha2::Sha256 as sha2_256;
14use std::fmt;
15use url::form_urlencoded;
16
17use super::canal::{Canal, PoolType};
18use crate::blocking::{AuthType, Handler};
19use crate::error::Error;
20use crate::tokio_async::traits::{DataPool, Filter, S3Folder};
21use crate::utils::{
22    s3object_list_xml_parser, upload_id_xml_parser, S3Convert, S3Object, UrlStyle, DEFAULT_REGION,
23};
24
25type UTCTime = DateTime<Utc>;
26
27pub trait Signer: Send + Sync + DynClone + fmt::Debug {
28    /// This method will setup the header and put the authorize string
29    fn sign(&self, _request: &mut Request, _now: &UTCTime) {
30        unimplemented!()
31    }
32
33    /// This method will be called once the resource change the region stored
34    fn update_region(&mut self, _region: String) {}
35}
36
37dyn_clone::clone_trait_object!(Signer);
38
39/// A dummy signer if you do not want to sign any request to access a public resource
40#[derive(Clone, Debug)]
41pub struct DummySigner {}
42
43impl Signer for DummySigner {
44    fn sign(&self, _requests: &mut Request, _now: &UTCTime) {}
45}
46
/// Credentials and dialect settings for AWS signature version 2.
///
/// `Debug` is implemented by hand so the secret key never ends up in logs or panic
/// messages; pools that hold this signer derive `Debug` and would otherwise print it.
#[derive(Clone)]
pub struct V2AuthSigner {
    /// Access key id placed in the `Authorization` header.
    pub access_key: String,
    /// Secret key used to key the HMAC-SHA1 signature.
    pub secret_key: String,
    /// Scheme name written before the credentials, e.g. `"AWS"`.
    pub auth_str: String,
    /// Prefix of provider-specific headers, e.g. `"x-amz"`.
    pub special_header_prefix: String,
}

impl fmt::Debug for V2AuthSigner {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("V2AuthSigner")
            .field("access_key", &self.access_key)
            .field("secret_key", &"<redacted>")
            .field("auth_str", &self.auth_str)
            .field("special_header_prefix", &self.special_header_prefix)
            .finish()
    }
}
54
55#[allow(dead_code)]
56impl V2AuthSigner {
57    /// new V2 auth signer compatible with AWS and CEPH
58    pub fn new(access_key: String, secret_key: String) -> Self {
59        V2AuthSigner {
60            access_key,
61            secret_key,
62            auth_str: "AWS".to_string(),
63            special_header_prefix: "x-amz".to_string(),
64        }
65    }
66    /// Setup the Auth string, if you are using customized S3
67    /// Default is "AWS"
68    pub fn auth_str(mut self, auth_str: String) -> Self {
69        self.auth_str = auth_str;
70        self
71    }
72
73    /// Setup the Special header prefix, if you are using customized S3
74    /// Default is "x-amz"
75    pub fn special_header_prefix(mut self, special_header_prefix: String) -> Self {
76        self.special_header_prefix = special_header_prefix;
77        self
78    }
79}
80
81impl Signer for V2AuthSigner {
82    fn sign(&self, request: &mut Request, _now: &UTCTime) {
83        let auth_string = format!(
84            "{} {}:{}",
85            self.auth_str,
86            self.access_key,
87            <Request as V2Signature>::sign(request, &self.secret_key)
88        );
89        let headers = request.headers_mut();
90        headers.insert(header::AUTHORIZATION, auth_string.parse().unwrap());
91    }
92}
93
/// Credentials and credential-scope settings for AWS signature version 4.
///
/// `Debug` is implemented by hand so the secret key is never printed; pools that hold
/// this signer derive `Debug` and would otherwise leak it.
#[derive(Clone)]
pub struct V4AuthSigner {
    /// Access key id placed in the `Credential=` part of `Authorization`.
    pub access_key: String,
    /// Secret key the signing key is derived from.
    pub secret_key: String,
    /// Region component of the credential scope.
    pub region: String,
    /// Service component of the credential scope, e.g. `"s3"`.
    pub service: String,
    /// Terminator of the credential scope, e.g. `"aws4_request"`.
    pub action: String,
    /// Algorithm name written at the start of `Authorization`, e.g. `"AWS4-HMAC-SHA256"`.
    pub auth_str: String,
    /// Prefix of provider-specific headers, e.g. `"x-amz"`.
    pub special_header_prefix: String,
}

impl fmt::Debug for V4AuthSigner {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("V4AuthSigner")
            .field("access_key", &self.access_key)
            .field("secret_key", &"<redacted>")
            .field("region", &self.region)
            .field("service", &self.service)
            .field("action", &self.action)
            .field("auth_str", &self.auth_str)
            .field("special_header_prefix", &self.special_header_prefix)
            .finish()
    }
}
104
105#[allow(dead_code)]
106impl V4AuthSigner {
107    /// new V4 Auth signer for AWS and CEPH
108    pub fn new(access_key: String, secret_key: String, region: String) -> Self {
109        V4AuthSigner {
110            access_key,
111            secret_key,
112            region,
113            service: "s3".to_string(),
114            action: "aws4_request".to_string(),
115            auth_str: "AWS4-HMAC-SHA256".to_string(),
116            special_header_prefix: "x-amz".to_string(),
117        }
118    }
119    /// Default is "us-east-1"
120    pub fn region(mut self, region: String) -> Self {
121        self.region = region;
122        self
123    }
124    /// Default is "s3"
125    pub fn service(mut self, service: String) -> Self {
126        self.service = service;
127        self
128    }
129    /// Default is "aws4_request"
130    pub fn action(mut self, action: String) -> Self {
131        self.action = action;
132        self
133    }
134    /// Setup the Auth string, if you are using customized S3
135    /// Default is "AWS4-HMAC-SHA256"
136    pub fn auth_str(mut self, auth_str: String) -> Self {
137        self.auth_str = auth_str;
138        self
139    }
140
141    /// Setup the Special header prefix, if you are using customized S3
142    /// Default is "x-amz"
143    pub fn special_header_prefix(mut self, special_header_prefix: String) -> Self {
144        self.special_header_prefix = special_header_prefix;
145        self
146    }
147}
148
149impl Signer for V4AuthSigner {
150    fn sign(&self, request: &mut Request, now: &UTCTime) {
151        let SignatureInfo {
152            signed_headers,
153            signature,
154        } = <Request as V4Signature>::sign(
155            request,
156            &self.auth_str,
157            now,
158            &self.secret_key,
159            &self.region,
160            &self.service,
161            &self.action,
162        );
163        let authorize_string = format!(
164            "{} Credential={}/{}/{}/{}/{}, SignedHeaders={}, Signature={}",
165            self.auth_str,
166            self.access_key,
167            now.format("%Y%m%d"),
168            self.region,
169            self.service,
170            self.action,
171            signed_headers,
172            signature
173        );
174        let headers = request.headers_mut();
175        headers.insert(header::AUTHORIZATION, authorize_string.parse().unwrap());
176    }
177    fn update_region(&mut self, region: String) {
178        self.region = region;
179    }
180}
181#[derive(Clone, Debug)]
182pub struct S3Pool {
183    pub host: String,
184    /// To use https or not, please note that integrity is secured by S3 protocol.
185    /// If the confidentiality is not under concerned, the http is good.
186    pub secure: bool,
187    /// Default will be Path style,
188    /// because Virtual hosted URLs may be supported for non-SSL requests only.
189    pub url_style: UrlStyle,
190
191    /// The part size for multipart, default disabled.
192    /// If Some the pull/push will check out the object size first and do mulitpart
193    /// If None download and upload will be in one part
194    pub part_size: Option<usize>,
195
196    client: Client,
197
198    /// The signer to adapt different protocol of data source
199    pub signer: Box<dyn Signer>,
200
201    objects: Vec<S3Object>,
202    filter: Option<Filter>,
203    is_truncated: bool,
204}
205
206impl S3Pool {
207    pub fn bucket(self, bucket_name: &str) -> Canal {
208        Canal {
209            up_pool: Some(Box::new(self)),
210            down_pool: None,
211            upstream_object: Some(bucket_name.into()),
212            downstream_object: None,
213            default: PoolType::UpPool,
214            filter: None,
215        }
216    }
217
218    pub fn resource(self, s3_object: S3Object) -> Canal {
219        Canal {
220            up_pool: Some(Box::new(self)),
221            down_pool: None,
222            upstream_object: Some(s3_object),
223            downstream_object: None,
224            default: PoolType::UpPool,
225            filter: None,
226        }
227    }
228
229    pub fn new(host: String) -> Self {
230        S3Pool {
231            host,
232            secure: false,
233            url_style: UrlStyle::PATH,
234            client: Client::new(),
235            signer: Box::new(DummySigner {}),
236            part_size: None,
237            objects: Vec::with_capacity(1000),
238            filter: None,
239            is_truncated: false,
240        }
241    }
242
243    pub fn aws_v2(mut self, access_key: String, secret_key: String) -> Self {
244        self.signer = Box::new(V2AuthSigner::new(access_key, secret_key));
245        self.url_style = UrlStyle::PATH;
246        self
247    }
248
249    pub fn aws_v4(mut self, access_key: String, secret_key: String, region: String) -> Self {
250        self.signer = Box::new(V4AuthSigner::new(access_key, secret_key, region));
251        self.url_style = UrlStyle::HOST;
252        self
253    }
254
255    pub fn endpoint_and_virturalhost(&self, desc: S3Object) -> (String, Option<String>) {
256        let ((host, uri), virturalhost) = match self.url_style {
257            UrlStyle::PATH => (desc.path_style_links(self.host.clone()), None),
258            UrlStyle::HOST => {
259                let (host, uri) = desc.virtural_host_style_links(self.host.clone());
260                ((host.clone(), uri), Some(host))
261            }
262        };
263        if self.secure {
264            (format!("https://{}{}", host, uri), virturalhost)
265        } else {
266            (format!("http://{}{}", host, uri), virturalhost)
267        }
268    }
269
270    pub fn init_headers(
271        &self,
272        headers: &mut HeaderMap,
273        now: &UTCTime,
274        virturalhost: Option<String>,
275    ) {
276        headers.insert(
277            header::DATE,
278            HeaderValue::from_str(now.to_rfc2822().as_str()).unwrap(),
279        );
280        headers.insert(
281            header::USER_AGENT,
282            HeaderValue::from_static("Rust S3 Handler"),
283        );
284        if let Some(virtural_host) = virturalhost {
285            headers.insert(header::HOST, HeaderValue::from_str(&virtural_host).unwrap());
286        } else {
287            headers.insert(header::HOST, HeaderValue::from_str(&self.host).unwrap());
288        }
289    }
290
291    fn handle_list_response(&mut self, body: String) -> Result<(), Error> {
292        (self.objects, self.is_truncated) = s3object_list_xml_parser(&body)?;
293        Ok(())
294    }
295
296    pub fn part_size(mut self, s: usize) -> Self {
297        self.part_size = Some(s);
298        self
299    }
300
301    /// Init multipart upload session, and return `multipart_id`
302    async fn init_multipart_upload(
303        &self,
304        url: String,
305        virturalhost: Option<String>,
306    ) -> Result<String, Error> {
307        let url = format!("{}?uploads", url);
308        let mut request = self.client.post(&url).build()?;
309
310        let now = Utc::now();
311        self.init_headers(request.headers_mut(), &now, virturalhost);
312        self.signer.sign(&mut request, &now);
313
314        let r = self.client.execute(request).await?;
315
316        Ok(upload_id_xml_parser(&r.text().await?)?)
317    }
318
319    async fn generate_part_upload_requests(
320        &self,
321        desc: S3Object,
322        multipart_id: &str,
323        part_size: usize,
324        object: Bytes,
325    ) -> Result<Vec<Result<Response, reqwest::Error>>, Error> {
326        let mut part_number = 0;
327        let mut start = 0;
328        let mut req_list = vec![];
329        while start < object.len() {
330            part_number += 1;
331            let end = if start + part_size >= object.len() {
332                object.len()
333            } else {
334                start + part_size
335            };
336            let (endpoint, virtural_host) = self.endpoint_and_virturalhost(desc.clone());
337            let url = format!(
338                "{}?uploadId={}&partNumber={}",
339                endpoint, multipart_id, part_number
340            );
341
342            let mut request = self
343                .client
344                .put(&url)
345                .body(object.slice(start..end))
346                .build()?;
347
348            let now = Utc::now();
349            self.init_headers(request.headers_mut(), &now, virtural_host);
350            self.signer.sign(&mut request, &now);
351            req_list.push(self.client.execute(request));
352            start += part_size
353        }
354        Ok(join_all(req_list).await)
355    }
356
357    async fn complete_multi_part_upload(
358        &self,
359        reqs: Vec<Result<Response, reqwest::Error>>,
360        desc: S3Object,
361        multipart_id: &str,
362    ) -> Result<Response, Error> {
363        let mut content = "<CompleteMultipartUpload>".to_string();
364        for (idx, res) in reqs.into_iter().enumerate() {
365            let r = res?;
366            let etag = r.headers()[reqwest::header::ETAG]
367                .to_str()
368                .expect("unexpected etag from server");
369
370            content.push_str(&format!(
371                "<Part><PartNumber>{}</PartNumber><ETag>{}</ETag></Part>",
372                idx + 1,
373                etag
374            ));
375        }
376        content.push_str(&"</CompleteMultipartUpload>".to_string());
377        let (endpoint, virturalhost) = self.endpoint_and_virturalhost(desc);
378        let url = format!("{}?uploadId={}", endpoint, multipart_id);
379        let mut request = self.client.post(&url).body(content.into_bytes()).build()?;
380        let now = Utc::now();
381        self.init_headers(request.headers_mut(), &now, virturalhost);
382        self.signer.sign(&mut request, &now);
383        let r = self.client.execute(request).await?;
384        Ok(r)
385    }
386
387    async fn generate_part_download_requests(
388        &self,
389        desc: S3Object,
390        part_size: usize,
391    ) -> Result<Vec<Result<Response, reqwest::Error>>, Error> {
392        let mut start = 0;
393        let mut req_list = vec![];
394        while start < desc.size.unwrap() {
395            let end = if start + part_size >= desc.size.unwrap() {
396                desc.size.unwrap()
397            } else {
398                start + part_size
399            };
400            let (url, virturalhost) = self.endpoint_and_virturalhost(desc.clone());
401
402            let mut request = self.client.get(&url).build()?;
403
404            let headers = request.headers_mut();
405            headers.insert(
406                header::RANGE,
407                HeaderValue::from_str(&format!("bytes={}-{}", start, end - 1)).unwrap(),
408            );
409
410            let now = Utc::now();
411            self.init_headers(headers, &now, virturalhost);
412            self.signer.sign(&mut request, &now);
413            req_list.push(self.client.execute(request));
414            start += part_size
415        }
416        Ok(join_all(req_list).await)
417    }
418
419    async fn complete_multi_part_download(
420        &self,
421        reqs: Vec<Result<Response, reqwest::Error>>,
422    ) -> Result<Bytes, Error> {
423        let mut output = BytesMut::with_capacity(0);
424        for res in reqs.into_iter() {
425            let r = res?;
426            // TODO: no copy, check out a way of Bytes -> BytesMut then using unsplit
427            output.extend_from_slice(&r.bytes().await?);
428        }
429        Ok(output.into())
430    }
431
432    async fn update_list(&mut self) -> Result<S3Object, Error> {
433        let last_object = self.objects.remove(0);
434        let mut params = Vec::<(&str, String)>::new();
435        if let Some(key) = &last_object.key {
436            params.push(("list-type", "2".to_string()));
437            params.push((
438                "start-after",
439                key.to_string()
440                    .strip_prefix('/')
441                    .expect("key should start with /")
442                    .to_string(),
443            ));
444        }
445
446        let mut bucket_object = last_object.clone();
447        bucket_object.key = None;
448        let (endpoint, virturalhost) = self.endpoint_and_virturalhost(bucket_object);
449        if let Some(Filter::Prefix(prefix)) = &self.filter {
450            params.push(("prefix", prefix.to_string()));
451        }
452        let url = if !params.is_empty() {
453            Url::parse_with_params(&endpoint, &params)?
454        } else {
455            Url::parse(&endpoint)?
456        };
457        let mut request = Request::new(Method::GET, url);
458
459        let now = Utc::now();
460        self.init_headers(request.headers_mut(), &now, virturalhost);
461        self.signer.sign(&mut request, &now);
462        let body = self.client.execute(request).await?.text().await?;
463        // TODO: validate start-after
464        self.handle_list_response(body)?;
465        Ok(last_object)
466    }
467}
468
469impl From<Handler<'_>> for S3Pool {
470    fn from(handler: Handler) -> Self {
471        let secure = handler.is_secure();
472        let Handler {
473            host,
474            access_key,
475            secret_key,
476            region,
477            auth_type,
478            url_style,
479            ..
480        } = handler;
481
482        let signer: Box<dyn Signer> = match auth_type {
483            AuthType::AWS4 => Box::new(V4AuthSigner::new(
484                access_key.into(),
485                secret_key.into(),
486                region.unwrap_or_else(|| DEFAULT_REGION.to_string()),
487            )),
488            AuthType::AWS2 => Box::new(V2AuthSigner::new(access_key.into(), secret_key.into())),
489        };
490
491        Self {
492            host: host.into(),
493            secure,
494            url_style,
495            client: Client::new(),
496            signer,
497            part_size: Some(5242880),
498            objects: Vec::with_capacity(1000),
499            filter: None,
500            is_truncated: false,
501        }
502    }
503}
504
505impl From<&Handler<'_>> for S3Pool {
506    fn from(handler: &Handler) -> Self {
507        let secure = handler.is_secure();
508        let Handler {
509            host,
510            access_key,
511            secret_key,
512            region,
513            auth_type,
514            url_style,
515            ..
516        } = handler;
517
518        let signer: Box<dyn Signer> = match auth_type {
519            AuthType::AWS4 => Box::new(V4AuthSigner::new(
520                access_key.to_string(),
521                secret_key.to_string(),
522                region.clone().unwrap_or_else(|| DEFAULT_REGION.to_string()),
523            )),
524            AuthType::AWS2 => Box::new(V2AuthSigner::new(
525                access_key.to_string(),
526                secret_key.to_string(),
527            )),
528        };
529
530        Self {
531            host: host.to_string(),
532            secure,
533            url_style: url_style.clone(),
534            client: Client::new(),
535            signer,
536            part_size: Some(5242880),
537            objects: Vec::with_capacity(1000),
538            filter: None,
539            is_truncated: false,
540        }
541    }
542}
543
544#[async_trait]
545impl DataPool for S3Pool {
546    async fn push(&self, desc: S3Object, object: Bytes) -> Result<(), Error> {
547        let part_size = self.part_size.unwrap_or_default();
548        let _r = if part_size > 0 && part_size < object.len() {
549            let (endpoint, virturalhost) = self.endpoint_and_virturalhost(desc.clone());
550            let multipart_id = self.init_multipart_upload(endpoint, virturalhost).await?;
551
552            let reqs = self
553                .generate_part_upload_requests(desc.clone(), &multipart_id, part_size, object)
554                .await?;
555            self.complete_multi_part_upload(reqs, desc, &multipart_id)
556                .await?
557        } else {
558            let (endpoint, virturalhost) = self.endpoint_and_virturalhost(desc);
559            let mut request = self.client.put(&endpoint).body(object).build()?;
560
561            let now = Utc::now();
562            self.init_headers(request.headers_mut(), &now, virturalhost);
563            self.signer.sign(&mut request, &now);
564            self.client.execute(request).await?
565        };
566        // TODO validate _r status code
567        Ok(())
568    }
569
570    async fn pull(&self, mut desc: S3Object) -> Result<Bytes, Error> {
571        self.fetch_meta(&mut desc).await?;
572        let part_size = self.part_size.unwrap_or_default();
573        if part_size > 0 && part_size < desc.size.unwrap_or_default() {
574            let reqs = self
575                .generate_part_download_requests(desc, part_size)
576                .await?;
577            let output = self.complete_multi_part_download(reqs).await?;
578
579            Ok(output)
580        } else {
581            // TODO reuse the client setting and not only the reqest
582            let (endpoint, virturalhost) = self.endpoint_and_virturalhost(desc);
583            let mut request = Request::new(Method::GET, Url::parse(&endpoint)?);
584
585            let now = Utc::now();
586            self.init_headers(request.headers_mut(), &now, virturalhost);
587            self.signer.sign(&mut request, &now);
588
589            let r = self.client.execute(request).await?;
590            // TODO validate status code
591            Ok(r.bytes().await?)
592        }
593    }
594
595    async fn list(
596        &self,
597        index: Option<S3Object>,
598        filter: &Option<Filter>,
599    ) -> Result<Box<dyn S3Folder>, Error> {
600        let mut pool = self.clone();
601        let (endpoint, virturalhost) = self.endpoint_and_virturalhost(index.unwrap_or_default());
602        let url = if let Some(Filter::Prefix(prefix)) = filter {
603            Url::parse_with_params(&endpoint, &[("prefix", prefix)])?
604        } else {
605            Url::parse(&endpoint)?
606        };
607        let mut request = Request::new(Method::GET, url);
608
609        let now = Utc::now();
610        pool.init_headers(request.headers_mut(), &now, virturalhost);
611        pool.signer.sign(&mut request, &now);
612        let body = pool.client.execute(request).await?.text().await?;
613        pool.handle_list_response(body)?;
614
615        // passing filter if the list did not complete
616        if filter.is_some() && pool.is_truncated {
617            pool.filter = Some(filter.as_ref().unwrap().clone());
618        }
619        Ok(Box::new(pool))
620    }
621
622    async fn remove(&self, desc: S3Object) -> Result<(), Error> {
623        let (endpoint, virturalhost) = self.endpoint_and_virturalhost(desc);
624        let mut request = Request::new(Method::DELETE, Url::parse(&endpoint)?);
625
626        let now = Utc::now();
627        self.init_headers(request.headers_mut(), &now, virturalhost);
628        self.signer.sign(&mut request, &now);
629
630        let _r = self.client.execute(request).await?;
631        // TODO validate status code
632        Ok(())
633    }
634
635    fn check_scheme(&self, scheme: &str) -> Result<(), Error> {
636        if scheme.to_lowercase() != "s3" {
637            Err(Error::SchemeError())
638        } else {
639            Ok(())
640        }
641    }
642
643    async fn fetch_meta(&self, desc: &mut S3Object) -> Result<(), Error> {
644        let (endpoint, virturalhost) = self.endpoint_and_virturalhost(desc.clone());
645        let mut request = self.client.head(&endpoint).build()?;
646
647        let now = Utc::now();
648        self.init_headers(request.headers_mut(), &now, virturalhost);
649        self.signer.sign(&mut request, &now);
650
651        let r = self.client.execute(request).await?;
652        let headers = r.headers();
653        desc.etag = if headers.contains_key(reqwest::header::ETAG) {
654            Some(
655                headers[reqwest::header::ETAG]
656                    .to_str()?
657                    .to_string()
658                    .replace('"', ""),
659            )
660        } else {
661            None
662        };
663        desc.mtime = if headers.contains_key(HeaderName::from_lowercase(b"last-modified").unwrap())
664        {
665            Some(
666                headers[HeaderName::from_lowercase(b"last-modified").unwrap()]
667                    .to_str()?
668                    .into(),
669            )
670        } else {
671            None
672        };
673        desc.size = if headers.contains_key(reqwest::header::CONTENT_LENGTH) {
674            Some(
675                headers[reqwest::header::CONTENT_LENGTH]
676                    .to_str()?
677                    .parse::<usize>()
678                    .unwrap_or_default(),
679            )
680        } else {
681            None
682        };
683
684        // TODO: check out it is correct or not that the storage class is absent here
685
686        Ok(())
687    }
688}
689
690#[async_trait]
691impl S3Folder for S3Pool {
692    async fn next_object(&mut self) -> Result<Option<S3Object>, Error> {
693        loop {
694            if self.objects.is_empty() {
695                return Ok(None);
696            } else {
697                let obj = if self.is_truncated && self.objects.len() == 1 {
698                    let last = self.update_list().await?;
699                    last
700                } else {
701                    self.objects.remove(0)
702                };
703                if obj.key.is_some() {
704                    return Ok(Some(obj));
705                }
706            }
707        }
708    }
709}
710
/// Canonical headers of a request together with the matching `SignedHeaders` value.
pub struct CanonicalHeadersInfo {
    /// Header names joined by `;`.
    pub signed_headers: String,
    /// One `name:value` line, newline-terminated, per signed header.
    pub canonical_headers: String,
}

/// A complete canonical request together with its `SignedHeaders` value.
pub struct CanonicalRequestInfo {
    /// Header names joined by `;`.
    pub signed_headers: String,
    /// The newline-separated canonical request text.
    pub canonical_request: String,
}

/// Canonical forms of a request consumed by the signature algorithms.
pub trait Canonical {
    /// Canonical headers block plus the names of the signed headers.
    fn canonical_headers_info(&self) -> CanonicalHeadersInfo;
    /// Sorted, percent-encoded query string.
    fn canonical_query_string(&self) -> String;
    /// Canonical request ending in `payload_hash`, plus the names of the signed headers.
    fn canonical_request_info(&self, payload_hash: &str) -> CanonicalRequestInfo;
}
726
727impl Canonical for Request {
728    fn canonical_headers_info(&self) -> CanonicalHeadersInfo {
729        let mut canonical_headers = String::new();
730        let mut signed_headers = Vec::new();
731
732        let mut headers: Vec<(String, &str)> = self
733            .headers()
734            .iter()
735            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default()))
736            .collect();
737
738        headers.sort_by(|a, b| a.0.to_lowercase().as_str().cmp(b.0.to_lowercase().as_str()));
739        for h in headers {
740            canonical_headers.push_str(h.0.to_lowercase().as_str());
741            canonical_headers.push(':');
742            canonical_headers.push_str(h.1.trim());
743            canonical_headers.push('\n');
744            signed_headers.push(h.0.to_lowercase());
745        }
746        CanonicalHeadersInfo {
747            signed_headers: signed_headers.join(";"),
748            canonical_headers,
749        }
750    }
751
752    fn canonical_query_string(&self) -> String {
753        let mut encoded = form_urlencoded::Serializer::new(String::new());
754        let mut qs: Vec<(String, String)> = self
755            .url()
756            .query_pairs()
757            .into_iter()
758            .map(|(k, v)| (k.as_ref().to_owned(), v.as_ref().to_owned()))
759            .collect();
760
761        qs.sort_by(|x, y| x.0.cmp(&y.0));
762
763        for (key, value) in qs {
764            encoded.append_pair(&key, &value);
765        }
766
767        // There is a `~` in upload id, should be treated in a tricky way.
768        //
769        // >>>
770        // In the concatenated string, period characters (.) are not escaped.
771        // RFC 3986 considers the period character an unreserved character,
772        // so it is **not** URL encoded.
773        // >>>
774        //
775        // ref:
776        // https://docs.aws.amazon.com/general/latest/gr/signature-version-2.html#create-canonical-string
777        encoded.finish().replace("%7E", "~")
778    }
779
780    fn canonical_request_info(&self, payload_hash: &str) -> CanonicalRequestInfo {
781        let CanonicalHeadersInfo {
782            signed_headers,
783            canonical_headers,
784        } = self.canonical_headers_info();
785        CanonicalRequestInfo {
786            signed_headers: signed_headers.clone(),
787            canonical_request: format!(
788                "{}\n{}\n{}\n{}\n{}\n{}",
789                self.method().as_str(),
790                self.url().path(),
791                self.canonical_query_string(),
792                canonical_headers,
793                signed_headers,
794                payload_hash
795            ),
796        }
797    }
798}
799
800pub trait V2Signature
801where
802    Self: Canonical,
803{
804    fn string_to_signed(&self) -> String;
805    fn sign(&self, sign_key: &str) -> String;
806}
807
808impl V2Signature for Request {
809    fn string_to_signed(&self) -> String {
810        format!(
811            "{}\n\n\n{}\n{}{}",
812            self.method().as_str(),
813            self.headers().get(header::DATE).unwrap().to_str().unwrap(),
814            self.url().path(),
815            self.canonical_query_string()
816        )
817    }
818    fn sign(&self, sign_key: &str) -> String {
819        encode(&hmacsha1::hmac_sha1(
820            sign_key.as_bytes(),
821            <Request as V2Signature>::string_to_signed(self).as_bytes(),
822        ))
823    }
824}
825
/// Hash of a canonical request together with the header names that were signed.
pub struct RequestHashInfo {
    /// Header names joined by `;`.
    pub signed_headers: String,
    /// Hex-encoded SHA-256 of the canonical request.
    pub sha256: String,
}

/// SigV4 string-to-sign together with the header names that were signed.
pub struct StringToSignInfo {
    /// Header names joined by `;`.
    pub signed_headers: String,
    /// The newline-separated string-to-sign.
    pub string_to_signed: String,
}

/// Final signature together with the header names that were signed.
pub struct SignatureInfo {
    /// Header names joined by `;`.
    pub signed_headers: String,
    /// The computed signature.
    pub signature: String,
}
840
841pub trait V4Signature
842where
843    Self: Canonical,
844{
845    fn string_to_signed(
846        &mut self,
847        auth_str: &str,
848        now: &UTCTime,
849        region: &str,
850        service: &str,
851        action: &str,
852    ) -> StringToSignInfo;
853    /// calculate hash mac and update header
854    fn payload_sha256(&mut self) -> String;
855    /// calculate hash mac and update header
856    fn request_sha256(&mut self) -> RequestHashInfo;
857    fn sign(
858        &mut self,
859        auth_str: &str,
860        now: &UTCTime,
861        sign_key: &str,
862        region: &str,
863        service: &str,
864        action: &str,
865    ) -> SignatureInfo;
866}
867
868impl V4Signature for Request {
869    fn string_to_signed(
870        &mut self,
871        auth_str: &str,
872        now: &UTCTime,
873        region: &str,
874        service: &str,
875        action: &str,
876    ) -> StringToSignInfo {
877        let iso_8601_str = {
878            let mut s = now.to_rfc3339();
879            s.retain(|c| !['-', ':'].contains(&c));
880            format!("{}Z", &s[..15])
881        };
882        let headers = self.headers_mut();
883        headers.insert(
884            header::HeaderName::from_static("x-amz-date"),
885            HeaderValue::from_str(&iso_8601_str).unwrap(),
886        );
887        let RequestHashInfo {
888            signed_headers,
889            sha256,
890        } = self.request_sha256();
891        StringToSignInfo {
892            signed_headers,
893            string_to_signed: format!(
894                "{}\n{}\n{}/{}/{}/{}\n{}",
895                auth_str,
896                iso_8601_str,
897                &iso_8601_str[..8],
898                region,
899                service,
900                action,
901                sha256
902            ),
903        }
904    }
905
906    fn payload_sha256(&mut self) -> String {
907        let mut sha = sha2_256::new();
908        sha.update(
909            self.body()
910                .map(|b| b.as_bytes())
911                .unwrap_or_default()
912                .unwrap_or_default(),
913        );
914        let payload_hash = hex::encode(sha.finalize().as_slice());
915        let headers = self.headers_mut();
916        headers.insert(
917            header::HeaderName::from_static("x-amz-content-sha256"),
918            HeaderValue::from_str(&payload_hash).unwrap(),
919        );
920        payload_hash
921    }
922
923    fn request_sha256(&mut self) -> RequestHashInfo {
924        let paload_hash = self.payload_sha256();
925
926        let CanonicalRequestInfo {
927            signed_headers,
928            canonical_request,
929        } = self.canonical_request_info(&paload_hash);
930
931        let mut sha = sha2_256::new();
932        sha.update(canonical_request.as_str());
933        RequestHashInfo {
934            signed_headers,
935            sha256: hex::encode(sha.finalize().as_slice()),
936        }
937    }
938
939    fn sign(
940        &mut self,
941        auth_str: &str,
942        now: &UTCTime,
943        sign_key: &str,
944        region: &str,
945        service: &str,
946        action: &str,
947    ) -> SignatureInfo {
948        let StringToSignInfo {
949            signed_headers,
950            string_to_signed,
951        } = <Request as V4Signature>::string_to_signed(
952            self, auth_str, now, region, service, action,
953        );
954        let time_str = {
955            let mut s = now.to_rfc3339();
956            s.retain(|c| !['-', ':'].contains(&c));
957            &s[..8].to_string()
958        };
959
960        let mut key: String = auth_str.split('-').next().unwrap_or_default().to_string();
961        key.push_str(sign_key);
962
963        let mut mac = Hmac::<sha2_256>::new_from_slice(key.as_str().as_bytes())
964            .expect("HMAC can take key of any size");
965        mac.update(time_str.as_bytes());
966        let result = mac.finalize();
967        let code_bytes = result.into_bytes();
968
969        let mut mac1 =
970            Hmac::<sha2_256>::new_from_slice(&code_bytes).expect("HMAC can take key of any size");
971        mac1.update(region.as_bytes());
972        let result1 = mac1.finalize();
973        let code_bytes1 = result1.into_bytes();
974
975        let mut mac2 =
976            Hmac::<sha2_256>::new_from_slice(&code_bytes1).expect("HMAC can take key of any size");
977        mac2.update(service.as_bytes());
978        let result2 = mac2.finalize();
979        let code_bytes2 = result2.into_bytes();
980
981        let mut mac3 =
982            Hmac::<sha2_256>::new_from_slice(&code_bytes2).expect("HMAC can take key of any size");
983        mac3.update(action.as_bytes());
984        let result3 = mac3.finalize();
985        let code_bytes3 = result3.into_bytes();
986
987        let mut mac4 =
988            Hmac::<sha2_256>::new_from_slice(&code_bytes3).expect("HMAC can take key of any size");
989        mac4.update(string_to_signed.as_bytes());
990        let result4 = mac4.finalize();
991        let code_bytes4 = result4.into_bytes();
992
993        SignatureInfo {
994            signed_headers,
995            signature: format!("{code_bytes4:02x}"),
996        }
997    }
998}
999
#[cfg(test)]
mod tests {
    use super::*;
    use crate::blocking::CredentialConfig;

    /// Feeds a captured `ListBucketResult` document (nine `<Contents>` entries,
    /// `<IsTruncated>false</IsTruncated>`) through `handle_list_response`.
    #[tokio::test]
    async fn test_handle_list_response() {
        // Assembled piecewise for readability. The result is one contiguous XML document
        // whose only newline follows the XML declaration.
        let xml = concat!(
            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n",
            "<ListBucketResult xmlns=\"http://s3.amazonaws.com/doc/2006-03-01/\">",
            "<Name>ant-lab</Name><Prefix></Prefix><Marker></Marker><MaxKeys>1000</MaxKeys>",
            "<IsTruncated>false</IsTruncated>",
            "<Contents><Key>14M</Key><LastModified>2020-01-31T14:58:45.000Z</LastModified><ETag>&quot;8ff43d748637d249d80d6f45e15c7663-3&quot;</ETag><Size>14336000</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "<Contents><Key>7M</Key><LastModified>2020-11-21T09:50:46.000Z</LastModified><ETag>&quot;cbe4f29b8b099989ae49afc02aa1c618-2&quot;</ETag><Size>7168000</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "<Contents><Key>7M.json</Key><LastModified>2020-09-19T14:59:23.000Z</LastModified><ETag>&quot;d34bd3f9aff10629ac49353312a42b0f-2&quot;</ETag><Size>7168000</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "<Contents><Key>get</Key><LastModified>2020-08-11T06:10:11.000Z</LastModified><ETag>&quot;f895d74af5106ce0c3d6cb008fb3b98d&quot;</ETag><Size>304</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "<Contents><Key>t</Key><LastModified>2020-09-19T15:10:08.000Z</LastModified><ETag>&quot;5050ef3558233dc04b3fac50eff68de1&quot;</ETag><Size>10</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "<Contents><Key>t.txt</Key><LastModified>2020-09-19T15:04:46.000Z</LastModified><ETag>&quot;5050ef3558233dc04b3fac50eff68de1&quot;</ETag><Size>10</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "<Contents><Key>test-orig</Key><LastModified>2020-11-21T09:48:29.000Z</LastModified><ETag>&quot;c059dadd468de1835bc99dab6e3b2cee-3&quot;</ETag><Size>11534336</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "<Contents><Key>test-s3handle</Key><LastModified>2020-11-21T10:09:39.000Z</LastModified><ETag>&quot;5dd39cab1c53c2c77cd352983f9641e1&quot;</ETag><Size>20</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "<Contents><Key>test.json</Key><LastModified>2020-08-11T09:54:42.000Z</LastModified><ETag>&quot;f895d74af5106ce0c3d6cb008fb3b98d&quot;</ETag><Size>304</Size>",
            "<Owner><ID>54bbddd7c9c485b696f5b188467d4bec889b83d3862d0a6db526d9d17aadcee2</ID><DisplayName>yanganto</DisplayName></Owner><StorageClass>STANDARD</StorageClass></Contents>",
            "</ListBucketResult>",
        );

        let mut s3_pool = S3Pool::new(String::from("somewhere.in.the.world"));
        s3_pool.handle_list_response(xml.to_string()).unwrap();
        assert!(!s3_pool.objects.is_empty());
        assert!(!s3_pool.is_truncated);
    }

    /// Converting a blocking `Handler` into an `S3Pool` keeps the configured host.
    /// Both the by-reference (`From<&Handler>`) and by-value (`into`) paths are checked.
    #[test]
    fn test_from_blocking_handle_to_s3_pool() {
        const HOST: &str = "s3.us-east-1.amazonaws.com";

        let config = CredentialConfig {
            host: HOST.to_string(),
            access_key: "akey".to_string(),
            secret_key: "skey".to_string(),
            user: None,
            region: None,  // default is us-east-1
            s3_type: None, // default will try to config as AWS S3 handler
            secure: None,  // default is false, because integrity is protected by HMAC
        };
        let handler = Handler::from(&config);
        let expected = S3Pool::new(HOST.to_string());

        let borrowed = S3Pool::from(&handler);
        assert_eq!(borrowed.host, expected.host);

        let owned: S3Pool = handler.into();
        assert_eq!(owned.host, expected.host);
    }
}