// s3/request/request_trait.rs
1use base64::Engine;
2use base64::engine::general_purpose;
3use hmac::Mac;
4use quick_xml::se::to_string;
5use std::collections::HashMap;
6#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
7use std::pin::Pin;
8use time::OffsetDateTime;
9use time::format_description::well_known::Rfc2822;
10use url::Url;
11
12use crate::LONG_DATETIME;
13use crate::bucket::Bucket;
14use crate::command::Command;
15use crate::error::S3Error;
16use crate::signing;
17use bytes::Bytes;
18use http::HeaderMap;
19use http::header::{
20    ACCEPT, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, DATE, HOST, HeaderName, RANGE,
21};
22use std::fmt::Write as _;
23
24#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
25use futures_util::Stream;
26#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
27use futures_util::stream::{self, StreamExt};
28
/// Fully-buffered response from an S3 request: the body bytes, the HTTP
/// status code, and the response headers flattened into a string map.
#[derive(Debug)]
pub struct ResponseData {
    bytes: Bytes,
    status_code: u16,
    headers: HashMap<String, String>,
}
36
/// Pinned, boxed stream of body chunks, each a [`StreamItem`].
#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
pub type DataStream = Pin<Box<dyn Stream<Item = StreamItem> + Send>>;
/// A single chunk yielded by a [`DataStream`].
#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
pub type StreamItem = Result<Bytes, S3Error>;

/// Streaming response: the body is consumed incrementally as a chunk
/// stream instead of being buffered, alongside the HTTP status code.
#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
pub struct ResponseDataStream {
    pub bytes: DataStream,
    pub status_code: u16,
}
47
#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
impl ResponseDataStream {
    /// Mutable access to the underlying chunk stream, e.g. for polling
    /// chunks with `StreamExt::next`.
    pub fn bytes(&mut self) -> &mut DataStream {
        &mut self.bytes
    }
}
54
55impl From<ResponseData> for Vec<u8> {
56    fn from(data: ResponseData) -> Vec<u8> {
57        data.to_vec()
58    }
59}
60
impl ResponseData {
    /// Bundles a fully-buffered body with its status code and headers.
    pub fn new(bytes: Bytes, status_code: u16, headers: HashMap<String, String>) -> ResponseData {
        ResponseData {
            bytes,
            status_code,
            headers,
        }
    }

    /// Body as a borrowed byte slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.bytes
    }

    /// Consumes the response and copies the body into an owned `Vec<u8>`.
    pub fn to_vec(self) -> Vec<u8> {
        self.bytes.to_vec()
    }

    /// Borrows the body as `Bytes`.
    pub fn bytes(&self) -> &Bytes {
        &self.bytes
    }

    /// Mutable access to the body buffer.
    pub fn bytes_mut(&mut self) -> &mut Bytes {
        &mut self.bytes
    }

    /// Consumes the response, returning the body without copying.
    pub fn into_bytes(self) -> Bytes {
        self.bytes
    }

    /// HTTP status code of the response.
    pub fn status_code(&self) -> u16 {
        self.status_code
    }

    /// Body interpreted as UTF-8, borrowed.
    pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> {
        std::str::from_utf8(self.as_slice())
    }

    /// Body interpreted as UTF-8, owned.
    ///
    /// NOTE(review): this inherent method shadows the `to_string`
    /// supplied by the `Display` impl below but returns a `Result`;
    /// renaming would break callers, so it is only flagged here.
    pub fn to_string(&self) -> Result<String, std::str::Utf8Error> {
        std::str::from_utf8(self.as_slice()).map(|s| s.to_string())
    }

    /// Clone of the response headers as a plain string map.
    pub fn headers(&self) -> HashMap<String, String> {
        self.headers.clone()
    }
}
106
107use std::fmt;
108
109impl fmt::Display for ResponseData {
110    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
111        write!(
112            f,
113            "Status code: {}\n Data: {}",
114            self.status_code(),
115            self.to_string()
116                .unwrap_or_else(|_| "Data could not be cast to UTF string".to_string())
117        )
118    }
119}
120
#[cfg(feature = "with-tokio")]
impl tokio::io::AsyncRead for ResponseDataStream {
    /// Adapts the inner chunk stream to tokio's `AsyncRead`.
    ///
    /// Reads at most one chunk per call; when a chunk is larger than
    /// `buf`, the unread tail is pushed back onto the front of the
    /// stream so the next `poll_read` picks it up.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // A zero-capacity buffer cannot accept data.
        if buf.remaining() == 0 {
            return std::task::Poll::Ready(Ok(()));
        }

        loop {
            match Stream::poll_next(self.bytes.as_mut(), cx) {
                std::task::Poll::Ready(Some(Ok(chunk))) => {
                    // Skip empty chunks; a ready result with no bytes
                    // written would read as EOF to callers.
                    if chunk.is_empty() {
                        continue;
                    }

                    let amt = std::cmp::min(chunk.len(), buf.remaining());
                    buf.put_slice(&chunk[..amt]);

                    // If the chunk did not fit entirely, re-queue the
                    // remainder ahead of the rest of the stream.
                    if amt < chunk.len() {
                        let remainder = chunk.slice(amt..);
                        let previous_stream =
                            std::mem::replace(&mut self.bytes, Box::pin(stream::empty()));
                        self.bytes = Box::pin(
                            stream::once(async move { Ok(remainder) }).chain(previous_stream),
                        );
                    }

                    return std::task::Poll::Ready(Ok(()));
                }
                std::task::Poll::Ready(Some(Err(error))) => {
                    // Surface stream errors as std::io errors.
                    return std::task::Poll::Ready(Err(std::io::Error::other(error)));
                }
                std::task::Poll::Ready(None) => {
                    // Stream exhausted: Ready(Ok(())) with no bytes == EOF.
                    return std::task::Poll::Ready(Ok(()));
                }
                std::task::Poll::Pending => return std::task::Poll::Pending,
            }
        }
    }
}
164
#[cfg(feature = "with-async-std")]
impl async_std::io::Read for ResponseDataStream {
    /// Adapts the inner chunk stream to async-std's `Read`.
    ///
    /// Reads at most one chunk per call; when a chunk is larger than
    /// `buf`, the unread tail is pushed back onto the front of the
    /// stream for the next call. Returning `Ok(0)` signals EOF.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        // A zero-length destination cannot accept data.
        if buf.is_empty() {
            return std::task::Poll::Ready(Ok(0));
        }

        loop {
            match Stream::poll_next(self.bytes.as_mut(), cx) {
                std::task::Poll::Ready(Some(Ok(chunk))) => {
                    // Skip empty chunks; an `Ok(0)` here would read as EOF.
                    if chunk.is_empty() {
                        continue;
                    }

                    let amt = std::cmp::min(chunk.len(), buf.len());
                    buf[..amt].copy_from_slice(&chunk[..amt]);

                    // Re-queue any unconsumed remainder ahead of the
                    // rest of the stream.
                    if amt < chunk.len() {
                        let remainder = chunk.slice(amt..);
                        let previous_stream =
                            std::mem::replace(&mut self.bytes, Box::pin(stream::empty()));
                        self.bytes = Box::pin(
                            stream::once(async move { Ok(remainder) }).chain(previous_stream),
                        );
                    }

                    return std::task::Poll::Ready(Ok(amt));
                }
                std::task::Poll::Ready(Some(Err(error))) => {
                    // Surface stream errors as std::io errors.
                    return std::task::Poll::Ready(Err(std::io::Error::other(error)));
                }
                std::task::Poll::Ready(None) => {
                    // Stream exhausted: report EOF.
                    return std::task::Poll::Ready(Ok(0));
                }
                std::task::Poll::Pending => return std::task::Poll::Pending,
            }
        }
    }
}
208
/// Backend-agnostic description of a single S3 HTTP request.
///
/// Implementors supply the transport (the `response*` methods); the
/// provided methods build URLs, headers and AWS SigV4 signatures.
#[maybe_async::maybe_async]
pub trait Request {
    /// Backend-specific HTTP response type.
    type Response;
    /// Backend-specific header map returned by `response_header`.
    type HeaderMap;

    /// Executes the request and returns the raw backend response.
    async fn response(&self) -> Result<Self::Response, S3Error>;
    /// Executes the request, buffering the entire body in memory.
    /// The `etag` flag is interpreted by each backend — TODO confirm
    /// its exact semantics against the implementations.
    async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error>;
    /// Streams the response body into `writer`, returning the HTTP
    /// status code (tokio backend).
    #[cfg(feature = "with-tokio")]
    async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin + ?Sized>(
        &self,
        writer: &mut T,
    ) -> Result<u16, S3Error>;
    /// Streams the response body into `writer`, returning the HTTP
    /// status code (async-std backend).
    #[cfg(feature = "with-async-std")]
    async fn response_data_to_writer<T: async_std::io::Write + Send + Unpin + ?Sized>(
        &self,
        writer: &mut T,
    ) -> Result<u16, S3Error>;
    /// Streams the response body into `writer`, returning the HTTP
    /// status code (blocking backend).
    #[cfg(feature = "sync")]
    fn response_data_to_writer<T: std::io::Write + Send + ?Sized>(
        &self,
        writer: &mut T,
    ) -> Result<u16, S3Error>;
    /// Returns the response body as a lazily-consumed chunk stream.
    #[cfg(any(feature = "with-async-std", feature = "with-tokio"))]
    async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error>;
    /// Performs the request, returning only headers and status code.
    async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error>;
    /// Convenience wrapper returning only the HTTP status code.
    async fn response_status(&self) -> Result<u16, S3Error> {
        let (_, status_code) = self.response_header().await?;
        Ok(status_code)
    }
    /// Timestamp used for signing this request.
    fn datetime(&self) -> OffsetDateTime;
    /// The bucket this request targets.
    fn bucket(&self) -> Bucket;
    /// The S3 command this request performs.
    fn command(&self) -> Command<'_>;
    /// Object path (key) within the bucket.
    fn path(&self) -> String;
242
    /// Derives the AWS SigV4 signing key for this request's date and
    /// region, scoped to the "s3" service.
    ///
    /// # Panics
    /// Panics if the bucket has no secret key configured — signing is
    /// impossible without one.
    async fn signing_key(&self) -> Result<Vec<u8>, S3Error> {
        signing::signing_key(
            &self.datetime(),
            &self
                .bucket()
                .secret_key()
                .await?
                .expect("Secret key must be provided to sign headers, found None"),
            &self.bucket().region(),
            "s3",
        )
    }
255
256    fn request_body(&self) -> Result<Vec<u8>, S3Error> {
257        let result = if let Command::PutObject { content, .. } = self.command() {
258            Vec::from(content)
259        } else if let Command::PutObjectTagging { tags } = self.command() {
260            Vec::from(tags)
261        } else if let Command::UploadPart { content, .. } = self.command() {
262            Vec::from(content)
263        } else if let Command::CompleteMultipartUpload { data, .. } = &self.command() {
264            let body = data.to_string();
265            body.as_bytes().to_vec()
266        } else if let Command::CreateBucket { config } = &self.command() {
267            if let Some(payload) = config.location_constraint_payload() {
268                Vec::from(payload)
269            } else {
270                Vec::new()
271            }
272        } else if let Command::PutBucketLifecycle { configuration, .. } = &self.command() {
273            quick_xml::se::to_string(configuration)?.as_bytes().to_vec()
274        } else if let Command::PutBucketCors { configuration, .. } = &self.command() {
275            let cors = configuration.to_string();
276            cors.as_bytes().to_vec()
277        } else if let Command::DeleteObjects { data } = &self.command() {
278            data.to_string().as_bytes().to_vec()
279        } else {
280            Vec::new()
281        };
282        Ok(result)
283    }
284
285    fn long_date(&self) -> Result<String, S3Error> {
286        Ok(self.datetime().format(LONG_DATETIME)?)
287    }
288
    /// Produces the SigV4 "string to sign" for the given canonical request.
    fn string_to_sign(&self, request: &str) -> Result<String, S3Error> {
        signing::string_to_sign(&self.datetime(), &self.bucket().region(), request)
    }

    /// Value for the HTTP `Host` header, taken from the bucket.
    fn host_header(&self) -> String {
        self.bucket().host()
    }
296
    /// Generates a presigned URL, including the trailing
    /// `X-Amz-Signature` query parameter (async implementation).
    #[maybe_async::async_impl]
    async fn presigned(&self) -> Result<String, S3Error> {
        let (expiry, custom_headers, custom_queries) = match self.command() {
            Command::PresignGet {
                expiry_secs,
                custom_queries,
            } => (expiry_secs, None, custom_queries),
            Command::PresignPut {
                expiry_secs,
                custom_headers,
                custom_queries,
            } => (expiry_secs, custom_headers, custom_queries),
            Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
            // `presigned` is only ever invoked for the presign commands.
            _ => unreachable!(),
        };

        let url = self
            .presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())
            .await?;

        // Build the URL string preserving the original host (including standard ports)
        // The Url type drops standard ports when converting to string, but we need them
        // for signature validation
        let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket().region()
        {
            // Check if we need to preserve a standard port
            // NOTE(review): `contains(":80")`/`contains(":443")` is a loose
            // substring test, but it is guarded by `url.port().is_none()`,
            // which only holds when the URL uses the scheme's default port.
            if (endpoint.contains(":80") && url.scheme() == "http" && url.port().is_none())
                || (endpoint.contains(":443") && url.scheme() == "https" && url.port().is_none())
            {
                // Rebuild the URL with the original host from the endpoint
                let host = self.bucket().host();
                format!(
                    "{}://{}{}{}",
                    url.scheme(),
                    host,
                    url.path(),
                    url.query().map(|q| format!("?{}", q)).unwrap_or_default()
                )
            } else {
                url.to_string()
            }
        } else {
            url.to_string()
        };

        Ok(format!(
            "{}&X-Amz-Signature={}",
            url_str,
            self.presigned_authorization(custom_headers.as_ref())
                .await?
        ))
    }
349
350    #[maybe_async::sync_impl]
351    async fn presigned(&self) -> Result<String, S3Error> {
352        let (expiry, custom_headers, custom_queries) = match self.command() {
353            Command::PresignGet {
354                expiry_secs,
355                custom_queries,
356            } => (expiry_secs, None, custom_queries),
357            Command::PresignPut {
358                expiry_secs,
359                custom_headers,
360                ..
361            } => (expiry_secs, custom_headers, None),
362            Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
363            _ => unreachable!(),
364        };
365
366        let url =
367            self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?;
368
369        // Build the URL string preserving the original host (including standard ports)
370        // The Url type drops standard ports when converting to string, but we need them
371        // for signature validation
372        let url_str = if let awsregion::Region::Custom { ref endpoint, .. } = self.bucket().region()
373        {
374            // Check if we need to preserve a standard port
375            if (endpoint.contains(":80") && url.scheme() == "http" && url.port().is_none())
376                || (endpoint.contains(":443") && url.scheme() == "https" && url.port().is_none())
377            {
378                // Rebuild the URL with the original host from the endpoint
379                let host = self.bucket().host();
380                format!(
381                    "{}://{}{}{}",
382                    url.scheme(),
383                    host,
384                    url.path(),
385                    url.query().map(|q| format!("?{}", q)).unwrap_or_default()
386                )
387            } else {
388                url.to_string()
389            }
390        } else {
391            url.to_string()
392        };
393
394        Ok(format!(
395            "{}&X-Amz-Signature={}",
396            url_str,
397            self.presigned_authorization(custom_headers.as_ref())?
398        ))
399    }
400
401    async fn presigned_authorization(
402        &self,
403        custom_headers: Option<&HeaderMap>,
404    ) -> Result<String, S3Error> {
405        let mut headers = HeaderMap::new();
406        let host_header = self.host_header();
407        headers.insert(HOST, host_header.parse()?);
408        if let Some(custom_headers) = custom_headers {
409            for (k, v) in custom_headers.iter() {
410                headers.insert(k.clone(), v.clone());
411            }
412        }
413        let canonical_request = self.presigned_canonical_request(&headers).await?;
414        let string_to_sign = self.string_to_sign(&canonical_request)?;
415        let mut hmac = signing::HmacSha256::new_from_slice(&self.signing_key().await?)?;
416        hmac.update(string_to_sign.as_bytes());
417        let signature = hex::encode(hmac.finalize().into_bytes());
418        // let signed_header = signing::signed_header_string(&headers);
419        Ok(signature)
420    }
421
    /// Builds the SigV4 canonical request for a presigned URL, using
    /// `UNSIGNED-PAYLOAD` since the body is not known at signing time.
    async fn presigned_canonical_request(&self, headers: &HeaderMap) -> Result<String, S3Error> {
        let (expiry, custom_headers, custom_queries) = match self.command() {
            Command::PresignGet {
                expiry_secs,
                custom_queries,
            } => (expiry_secs, None, custom_queries),
            Command::PresignPut {
                expiry_secs,
                custom_headers,
                custom_queries,
            } => (expiry_secs, custom_headers, custom_queries),
            Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
            // Only presign commands ever reach this method.
            _ => unreachable!(),
        };

        signing::canonical_request(
            &self.command().http_verb().to_string(),
            &self
                .presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())
                .await?,
            headers,
            "UNSIGNED-PAYLOAD",
        )
    }
446
447    #[maybe_async::async_impl]
448    async fn presigned_url_no_sig(
449        &self,
450        expiry: u32,
451        custom_headers: Option<&HeaderMap>,
452        custom_queries: Option<&HashMap<String, String>>,
453    ) -> Result<Url, S3Error> {
454        let bucket = self.bucket();
455        let token = if let Some(security_token) = bucket.security_token().await? {
456            Some(security_token)
457        } else {
458            bucket.session_token().await?
459        };
460        let url = Url::parse(&format!(
461            "{}{}{}",
462            self.url()?,
463            &signing::authorization_query_params_no_sig(
464                &self.bucket().access_key().await?.unwrap_or_default(),
465                &self.datetime(),
466                &self.bucket().region(),
467                expiry,
468                custom_headers,
469                token.as_ref()
470            )?,
471            &signing::flatten_queries(custom_queries)?,
472        ))?;
473
474        Ok(url)
475    }
476
477    #[maybe_async::sync_impl]
478    fn presigned_url_no_sig(
479        &self,
480        expiry: u32,
481        custom_headers: Option<&HeaderMap>,
482        custom_queries: Option<&HashMap<String, String>>,
483    ) -> Result<Url, S3Error> {
484        let bucket = self.bucket();
485        let token = if let Some(security_token) = bucket.security_token()? {
486            Some(security_token)
487        } else {
488            bucket.session_token()?
489        };
490        let url = Url::parse(&format!(
491            "{}{}{}",
492            self.url()?,
493            &signing::authorization_query_params_no_sig(
494                &self.bucket().access_key()?.unwrap_or_default(),
495                &self.datetime(),
496                &self.bucket().region(),
497                expiry,
498                custom_headers,
499                token.as_ref()
500            )?,
501            &signing::flatten_queries(custom_queries)?,
502        ))?;
503
504        Ok(url)
505    }
506
507    fn url(&self) -> Result<Url, S3Error> {
508        let mut url_str = self.bucket().url();
509
510        if let Command::ListBuckets { .. } = self.command() {
511            return Ok(Url::parse(&url_str)?);
512        }
513
514        if let Command::CreateBucket { .. } = self.command() {
515            return Ok(Url::parse(&url_str)?);
516        }
517
518        let path = if self.path().starts_with('/') {
519            self.path()[1..].to_string()
520        } else {
521            self.path()[..].to_string()
522        };
523
524        url_str.push('/');
525        url_str.push_str(&signing::uri_encode(&path, false));
526
527        // Append to url_path
528        #[allow(clippy::collapsible_match)]
529        match self.command() {
530            Command::InitiateMultipartUpload { .. } | Command::ListMultipartUploads { .. } => {
531                url_str.push_str("?uploads")
532            }
533            Command::AbortMultipartUpload { upload_id } => {
534                write!(url_str, "?uploadId={}", upload_id).expect("Could not write to url_str");
535            }
536            Command::CompleteMultipartUpload { upload_id, .. } => {
537                write!(url_str, "?uploadId={}", upload_id).expect("Could not write to url_str");
538            }
539            Command::GetObjectTorrent => url_str.push_str("?torrent"),
540            Command::PutObject { multipart, .. } => {
541                if let Some(multipart) = multipart {
542                    url_str.push_str(&multipart.query_string())
543                }
544            }
545            Command::GetBucketLifecycle
546            | Command::PutBucketLifecycle { .. }
547            | Command::DeleteBucketLifecycle => {
548                url_str.push_str("?lifecycle");
549            }
550            Command::GetBucketCors { .. }
551            | Command::PutBucketCors { .. }
552            | Command::DeleteBucketCors { .. } => {
553                url_str.push_str("?cors");
554            }
555            Command::GetObjectAttributes { version_id, .. } => {
556                if let Some(version_id) = version_id {
557                    url_str.push_str(&format!("?attributes&versionId={}", version_id));
558                } else {
559                    url_str.push_str("?attributes&versionId=null");
560                }
561            }
562            Command::HeadObject => {}
563            Command::DeleteObject => {}
564            Command::DeleteObjectTagging => {}
565            Command::GetObject => {}
566            Command::GetObjectRange { .. } => {}
567            Command::GetObjectTagging => {}
568            Command::ListObjects { .. } => {}
569            Command::ListObjectsV2 { .. } => {}
570            Command::GetBucketLocation => {}
571            Command::PresignGet { .. } => {}
572            Command::PresignPut { .. } => {}
573            Command::PresignDelete { .. } => {}
574            Command::DeleteBucket => {}
575            Command::ListBuckets => {}
576            Command::CopyObject { .. } => {}
577            Command::PutObjectTagging { .. } => {}
578            Command::UploadPart { .. } => {}
579            Command::CreateBucket { .. } => {}
580            Command::DeleteObjects { .. } => {
581                url_str.push_str("?delete");
582            }
583        }
584
585        let mut url = Url::parse(&url_str)?;
586
587        for (key, value) in &self.bucket().extra_query {
588            url.query_pairs_mut().append_pair(key, value);
589        }
590
591        if let Command::ListObjectsV2 {
592            prefix,
593            delimiter,
594            continuation_token,
595            start_after,
596            max_keys,
597        } = self.command().clone()
598        {
599            let mut query_pairs = url.query_pairs_mut();
600            delimiter.map(|d| query_pairs.append_pair("delimiter", &d));
601
602            query_pairs.append_pair("prefix", &prefix);
603            query_pairs.append_pair("list-type", "2");
604            if let Some(token) = continuation_token {
605                query_pairs.append_pair("continuation-token", &token);
606            }
607            if let Some(start_after) = start_after {
608                query_pairs.append_pair("start-after", &start_after);
609            }
610            if let Some(max_keys) = max_keys {
611                query_pairs.append_pair("max-keys", &max_keys.to_string());
612            }
613        }
614
615        if let Command::ListObjects {
616            prefix,
617            delimiter,
618            marker,
619            max_keys,
620        } = self.command().clone()
621        {
622            let mut query_pairs = url.query_pairs_mut();
623            delimiter.map(|d| query_pairs.append_pair("delimiter", &d));
624
625            query_pairs.append_pair("prefix", &prefix);
626            if let Some(marker) = marker {
627                query_pairs.append_pair("marker", &marker);
628            }
629            if let Some(max_keys) = max_keys {
630                query_pairs.append_pair("max-keys", &max_keys.to_string());
631            }
632        }
633
634        match self.command() {
635            Command::ListMultipartUploads {
636                prefix,
637                delimiter,
638                key_marker,
639                max_uploads,
640            } => {
641                let mut query_pairs = url.query_pairs_mut();
642                delimiter.map(|d| query_pairs.append_pair("delimiter", d));
643                if let Some(prefix) = prefix {
644                    query_pairs.append_pair("prefix", prefix);
645                }
646                if let Some(key_marker) = key_marker {
647                    query_pairs.append_pair("key-marker", &key_marker);
648                }
649                if let Some(max_uploads) = max_uploads {
650                    query_pairs.append_pair("max-uploads", max_uploads.to_string().as_str());
651                }
652            }
653            Command::PutObjectTagging { .. }
654            | Command::GetObjectTagging
655            | Command::DeleteObjectTagging => {
656                url.query_pairs_mut().append_pair("tagging", "");
657            }
658            _ => {}
659        }
660
661        Ok(url)
662    }
663
    /// Builds the SigV4 canonical request for a normal (header-signed)
    /// request, hashing the body via `Command::sha256`.
    fn canonical_request(&self, headers: &HeaderMap) -> Result<String, S3Error> {
        signing::canonical_request(
            &self.command().http_verb().to_string(),
            &self.url()?,
            headers,
            &self.command().sha256()?,
        )
    }
672
673    #[maybe_async::maybe_async]
674    async fn authorization(&self, headers: &HeaderMap) -> Result<String, S3Error> {
675        let canonical_request = self.canonical_request(headers)?;
676        let string_to_sign = self.string_to_sign(&canonical_request)?;
677        let mut hmac = signing::HmacSha256::new_from_slice(&self.signing_key().await?)?;
678        hmac.update(string_to_sign.as_bytes());
679        let signature = hex::encode(hmac.finalize().into_bytes());
680        let signed_header = signing::signed_header_string(headers);
681        signing::authorization_header(
682            &self
683                .bucket()
684                .access_key()
685                .await?
686                .expect("No access_key provided"),
687            &self.datetime(),
688            &self.bucket().region(),
689            &signed_header,
690            &signature,
691        )
692    }
693
694    #[maybe_async::maybe_async]
695    async fn headers(&self) -> Result<HeaderMap, S3Error> {
696        // Generate this once, but it's used in more than one place.
697        let sha256 = self.command().sha256()?;
698
699        // Start with extra_headers, that way our headers replace anything with
700        // the same name.
701
702        let mut headers = HeaderMap::new();
703
704        for (k, v) in self.bucket().extra_headers.iter() {
705            if k.as_str().starts_with("x-amz-meta-") {
706                // metadata is invalid on any multipart command other than initiate
707                match self.command() {
708                    Command::UploadPart { .. }
709                    | Command::AbortMultipartUpload { .. }
710                    | Command::CompleteMultipartUpload { .. }
711                    | Command::PutObject {
712                        multipart: Some(_), ..
713                    } => continue,
714                    _ => (),
715                }
716            }
717            headers.insert(k.clone(), v.clone());
718        }
719
720        // Append custom headers for PUT request if any
721        if let Command::PutObject { custom_headers, .. } = self.command()
722            && let Some(custom_headers) = custom_headers
723        {
724            for (k, v) in custom_headers.iter() {
725                headers.insert(k.clone(), v.clone());
726            }
727        }
728
729        let host_header = self.host_header();
730
731        headers.insert(HOST, host_header.parse()?);
732
733        match self.command() {
734            Command::CopyObject { from } => {
735                headers.insert(HeaderName::from_static("x-amz-copy-source"), from.parse()?);
736            }
737            Command::ListObjects { .. } => {}
738            Command::ListObjectsV2 { .. } => {}
739            Command::HeadObject => {}
740            Command::GetObject => {}
741            Command::GetObjectTagging => {}
742            Command::GetBucketLocation => {}
743            Command::ListBuckets => {}
744            _ => {
745                headers.insert(
746                    CONTENT_LENGTH,
747                    self.command().content_length()?.to_string().parse()?,
748                );
749                headers.insert(CONTENT_TYPE, self.command().content_type().parse()?);
750            }
751        }
752        headers.insert(
753            HeaderName::from_static("x-amz-content-sha256"),
754            sha256.parse()?,
755        );
756        headers.insert(
757            HeaderName::from_static("x-amz-date"),
758            self.long_date()?.parse()?,
759        );
760
761        if let Some(session_token) = self.bucket().session_token().await? {
762            headers.insert(
763                HeaderName::from_static("x-amz-security-token"),
764                session_token.parse()?,
765            );
766        } else if let Some(security_token) = self.bucket().security_token().await? {
767            headers.insert(
768                HeaderName::from_static("x-amz-security-token"),
769                security_token.parse()?,
770            );
771        }
772
773        if let Command::PutObjectTagging { tags } = self.command() {
774            let digest = md5::compute(tags);
775            let hash = general_purpose::STANDARD.encode(digest.as_ref());
776            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
777        } else if let Command::PutObject { content, .. } = self.command() {
778            let digest = md5::compute(content);
779            let hash = general_purpose::STANDARD.encode(digest.as_ref());
780            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
781        } else if let Command::UploadPart { content, .. } = self.command() {
782            let digest = md5::compute(content);
783            let hash = general_purpose::STANDARD.encode(digest.as_ref());
784            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
785        } else if let Command::GetObject {} = self.command() {
786            headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?);
787        // headers.insert(header::ACCEPT_CHARSET, HeaderValue::from_str("UTF-8")?);
788        } else if let Command::GetObjectRange { start, end } = self.command() {
789            headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?);
790
791            let mut range = format!("bytes={}-", start);
792
793            if let Some(end) = end {
794                range.push_str(&end.to_string());
795            }
796
797            headers.insert(RANGE, range.parse()?);
798        } else if let Command::CreateBucket { ref config } = self.command() {
799            config.add_headers(&mut headers)?;
800        } else if let Command::PutBucketLifecycle { ref configuration } = self.command() {
801            let digest = md5::compute(to_string(configuration)?.as_bytes());
802            let hash = general_purpose::STANDARD.encode(digest.as_ref());
803            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
804        } else if let Command::PutBucketCors {
805            expected_bucket_owner,
806            configuration,
807            ..
808        } = self.command()
809        {
810            let digest = md5::compute(configuration.to_string().as_bytes());
811            let hash = general_purpose::STANDARD.encode(digest.as_ref());
812            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
813
814            headers.insert(
815                HeaderName::from_static("x-amz-expected-bucket-owner"),
816                expected_bucket_owner.parse()?,
817            );
818        } else if let Command::GetBucketCors {
819            expected_bucket_owner,
820        } = self.command()
821        {
822            headers.insert(
823                HeaderName::from_static("x-amz-expected-bucket-owner"),
824                expected_bucket_owner.parse()?,
825            );
826        } else if let Command::DeleteBucketCors {
827            expected_bucket_owner,
828        } = self.command()
829        {
830            headers.insert(
831                HeaderName::from_static("x-amz-expected-bucket-owner"),
832                expected_bucket_owner.parse()?,
833            );
834        } else if let Command::GetObjectAttributes {
835            expected_bucket_owner,
836            ..
837        } = self.command()
838        {
839            headers.insert(
840                HeaderName::from_static("x-amz-expected-bucket-owner"),
841                expected_bucket_owner.parse()?,
842            );
843            headers.insert(
844                HeaderName::from_static("x-amz-object-attributes"),
845                "ETag".parse()?,
846            );
847        } else if let Command::DeleteObjects { ref data } = self.command() {
848            let body = data.to_string();
849            let digest = md5::compute(body.as_bytes());
850            let hash = general_purpose::STANDARD.encode(digest.as_ref());
851            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
852        }
853
854        // This must be last, as it signs the other headers, omitted if no secret key is provided
855        if self.bucket().secret_key().await?.is_some() {
856            let authorization = self.authorization(&headers).await?;
857            headers.insert(AUTHORIZATION, authorization.parse()?);
858        }
859
860        // The format of RFC2822 is somewhat malleable, so including it in
861        // signed headers can cause signature mismatches. We do include the
862        // X-Amz-Date header, so requests are still properly limited to a date
863            // range and can't be reused in, e.g., replay attacks. Adding this header
864        // after the generation of the Authorization header leaves it out of
865        // the signed headers.
866        headers.insert(DATE, self.datetime().format(&Rfc2822)?.parse()?);
867
868        Ok(headers)
869    }
870}
871
#[cfg(all(test, feature = "with-tokio"))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::stream;
    use tokio::io::AsyncReadExt;

    /// Wraps the given chunks in a `ResponseDataStream` with a 200 status,
    /// ready to be consumed through the `AsyncRead` adapter under test.
    fn stream_of(chunks: Vec<StreamItem>) -> ResponseDataStream {
        let bytes: DataStream = Box::pin(stream::iter(chunks));
        ResponseDataStream {
            bytes,
            status_code: 200,
        }
    }

    /// `read_to_end` must concatenate every chunk of the stream, in order.
    #[tokio::test]
    async fn test_async_read_implementation() {
        let mut response = stream_of(vec![
            Ok(Bytes::from("Hello, ")),
            Ok(Bytes::from("World!")),
            Ok(Bytes::from(" This is a test.")),
        ]);

        let mut collected = Vec::new();
        response.read_to_end(&mut collected).await.unwrap();

        assert_eq!(collected, b"Hello, World! This is a test.");
    }

    /// A chunk larger than the caller's buffer must be delivered across
    /// several reads without losing any bytes.
    #[tokio::test]
    async fn test_async_read_with_small_buffer_preserves_large_chunk() {
        let expected = b"This is a much longer string that won't fit in a small buffer";
        let mut response = stream_of(vec![Ok(Bytes::copy_from_slice(expected))]);

        // Drain through a 10-byte scratch buffer until EOF.
        let mut collected = Vec::new();
        let mut scratch = [0u8; 10];
        loop {
            match response.read(&mut scratch).await.unwrap() {
                0 => break,
                n => collected.extend_from_slice(&scratch[..n]),
            }
        }

        assert_eq!(collected, expected);
    }

    /// A stream error must surface as a read error after any preceding
    /// good chunks have been delivered intact.
    #[tokio::test]
    async fn test_async_read_with_error() {
        use crate::error::S3Error;

        // One good chunk followed by an injected I/O failure.
        let failure = S3Error::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            "Test error",
        ));
        let mut response = stream_of(vec![Ok(Bytes::from("Some data")), Err(failure)]);

        // The good chunk is readable as usual.
        let mut scratch = [0u8; 20];
        let n = response.read(&mut scratch).await.unwrap();
        assert_eq!(n, 9);
        assert_eq!(&scratch[..n], b"Some data");

        // The injected error is reported on the next read.
        assert!(response.read(&mut scratch).await.is_err());
    }

    /// The adapter must work with `tokio::io::copy`, a common consumer
    /// of `AsyncRead`.
    #[tokio::test]
    async fn test_async_read_copy() {
        let mut response = stream_of(vec![
            Ok(Bytes::from("First chunk\n")),
            Ok(Bytes::from("Second chunk\n")),
            Ok(Bytes::from("Third chunk\n")),
        ]);

        let mut sink = Vec::new();
        tokio::io::copy(&mut response, &mut sink).await.unwrap();

        assert_eq!(sink, b"First chunk\nSecond chunk\nThird chunk\n");
    }
}
987
#[cfg(all(test, feature = "with-async-std"))]
mod async_std_tests {
    use super::*;
    use async_std::io::ReadExt;
    use bytes::Bytes;
    use futures_util::stream;

    /// Wraps the given chunks in a `ResponseDataStream` with a 200 status,
    /// ready to be consumed through the `AsyncRead` adapter under test.
    fn stream_of(chunks: Vec<StreamItem>) -> ResponseDataStream {
        let bytes: DataStream = Box::pin(stream::iter(chunks));
        ResponseDataStream {
            bytes,
            status_code: 200,
        }
    }

    /// `read_to_end` must concatenate every chunk of the stream, in order.
    #[async_std::test]
    async fn test_async_read_implementation() {
        let mut response = stream_of(vec![
            Ok(Bytes::from("Hello, ")),
            Ok(Bytes::from("World!")),
            Ok(Bytes::from(" This is a test.")),
        ]);

        let mut collected = Vec::new();
        response.read_to_end(&mut collected).await.unwrap();

        assert_eq!(collected, b"Hello, World! This is a test.");
    }

    /// A chunk larger than the caller's buffer must be delivered across
    /// several reads without losing any bytes.
    #[async_std::test]
    async fn test_async_read_with_small_buffer_preserves_large_chunk() {
        let expected = b"This is a much longer string that won't fit in a small buffer";
        let mut response = stream_of(vec![Ok(Bytes::copy_from_slice(expected))]);

        // Drain through a 10-byte scratch buffer until EOF.
        let mut collected = Vec::new();
        let mut scratch = [0u8; 10];
        loop {
            match response.read(&mut scratch).await.unwrap() {
                0 => break,
                n => collected.extend_from_slice(&scratch[..n]),
            }
        }

        assert_eq!(collected, expected);
    }

    /// A stream error must surface as a read error after any preceding
    /// good chunks have been delivered intact.
    #[async_std::test]
    async fn test_async_read_with_error() {
        use crate::error::S3Error;

        // One good chunk followed by an injected I/O failure.
        let failure = S3Error::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            "Test error",
        ));
        let mut response = stream_of(vec![Ok(Bytes::from("Some data")), Err(failure)]);

        // The good chunk is readable as usual.
        let mut scratch = [0u8; 20];
        let n = response.read(&mut scratch).await.unwrap();
        assert_eq!(n, 9);
        assert_eq!(&scratch[..n], b"Some data");

        // The injected error is reported on the next read.
        assert!(response.read(&mut scratch).await.is_err());
    }

    /// The adapter must work with `async_std::io::copy`, a common consumer
    /// of `AsyncRead`.
    #[async_std::test]
    async fn test_async_read_copy() {
        let mut response = stream_of(vec![
            Ok(Bytes::from("First chunk\n")),
            Ok(Bytes::from("Second chunk\n")),
            Ok(Bytes::from("Third chunk\n")),
        ]);

        let mut sink = Vec::new();
        async_std::io::copy(&mut response, &mut sink)
            .await
            .unwrap();

        assert_eq!(sink, b"First chunk\nSecond chunk\nThird chunk\n");
    }
}