1use std::sync::Arc;
4
5use bytes::Bytes;
6use chrono::DateTime;
7use chrono::Utc;
8use http_cache_stream_reqwest::Cache;
9use http_cache_stream_reqwest::storage::DefaultCacheStorage;
10use reqwest::Body;
11use reqwest::Request;
12use reqwest::Response;
13use reqwest::StatusCode;
14use reqwest::header;
15use reqwest::header::HeaderValue;
16use secrecy::ExposeSecret;
17use serde::Deserialize;
18use serde::Serialize;
19use tokio::sync::broadcast;
20use tracing::debug;
21use url::Url;
22
23use crate::BLOCK_SIZE_THRESHOLD;
24use crate::Config;
25use crate::Error;
26use crate::HttpClient;
27use crate::ONE_MEBIBYTE;
28use crate::Result;
29use crate::S3AuthConfig;
30use crate::TransferEvent;
31use crate::USER_AGENT;
32use crate::UrlExt as _;
33use crate::backend::StorageBackend;
34use crate::backend::Upload;
35use crate::backend::auth::RequestSigner;
36use crate::backend::auth::SignatureProvider;
37use crate::backend::auth::sha256_hex_string;
38use crate::streams::ByteStream;
39use crate::streams::TransferStream;
40
/// Root domain of AWS S3 endpoints.
const AWS_ROOT_DOMAIN: &str = "amazonaws.com";

/// Root domain used by the LocalStack AWS emulator.
const LOCALSTACK_ROOT_DOMAIN: &str = "localhost.localstack.cloud";

/// Region assumed when no region is configured for `s3://` URLs.
const DEFAULT_REGION: &str = "us-east-1";

/// Maximum number of parts in a multipart upload (AWS-imposed limit).
const MAX_PARTS: u64 = 10000;

/// Minimum size of a multipart upload part: 5 MiB (AWS-imposed limit).
const MIN_PART_SIZE: u64 = 5 * ONE_MEBIBYTE;

/// Maximum size of a single part: 5 GiB (AWS-imposed limit).
const MAX_PART_SIZE: u64 = MIN_PART_SIZE * 1024;

/// Maximum supported source file size: 5 TiB (AWS-imposed object limit).
const MAX_FILE_SIZE: u64 = MAX_PART_SIZE * 1024;

/// Header carrying the request timestamp for SigV4 signing.
const AWS_DATE_HEADER: &str = "x-amz-date";

/// Header carrying the SHA-256 digest of the request payload.
const AWS_CONTENT_SHA256_HEADER: &str = "x-amz-content-sha256";
68
/// Errors specific to the S3 storage backend.
#[derive(Debug, thiserror::Error)]
pub enum S3Error {
    /// A user-configured block size exceeds the S3 maximum part size.
    #[error("S3 block size cannot exceed {MAX_PART_SIZE} bytes")]
    InvalidBlockSize,
    /// The source file is too large to upload even with the maximum part count.
    #[error("the size of the source file exceeds the supported maximum of {MAX_FILE_SIZE} bytes")]
    MaximumSizeExceeded,
    /// An `s3://` URL could not be rewritten into a supported HTTP(S) URL.
    #[error("invalid URL with `s3` scheme: the URL is not in a supported format")]
    InvalidScheme,
    /// The URL's path does not contain a bucket component.
    #[error("URL is missing the bucket in the path")]
    MissingBucket,
    /// The secret access key could not be used to produce a signature.
    #[error("invalid S3 secret access key")]
    InvalidSecretAccessKey,
    /// A part-upload response did not include the required `ETag` header.
    #[error("response from server was missing an ETag header")]
    ResponseMissingETag,
    /// The bucket name could not be set as a URL host.
    #[error("the bucket name specified in the URL is invalid")]
    InvalidBucketName,
    /// The server's response body could not be deserialized as expected XML.
    #[error("unexpected {status} response from server: failed to deserialize response contents: {error}", status = .status.as_u16())]
    UnexpectedResponse {
        /// The HTTP status of the unexpected response.
        status: reqwest::StatusCode,
        /// The underlying XML deserialization error.
        error: serde_xml_rs::Error,
    },
}
102
/// A single object entry in a `ListBucketResult` document.
#[derive(Debug, Deserialize)]
pub struct Content {
    /// The object's full key within the bucket.
    #[serde(rename = "Key")]
    pub key: String,
}
110
/// Deserialized S3 `ListObjectsV2` response document.
#[derive(Debug, Deserialize)]
#[serde(rename = "ListBucketResult")]
pub struct ListBucketResult {
    /// The object entries returned for this page of results.
    #[serde(default, rename = "Contents")]
    pub contents: Vec<Content>,
    /// Continuation token, present when more pages are available.
    #[serde(rename = "NextContinuationToken", default)]
    pub token: Option<String>,
}
122
/// Deserialized S3 `CreateMultipartUpload` response document.
#[derive(Default, Deserialize)]
#[serde(rename = "InitiateMultipartUploadResult")]
pub struct InitiateMultipartUploadResult {
    /// Identifier of the newly created multipart upload.
    #[serde(rename = "UploadId")]
    pub upload_id: String,
}
131
/// AWS SigV4 signing parameters for S3 requests.
pub struct S3SignatureProvider<'a> {
    /// Region used in the credential scope.
    region: &'a str,
    /// Access key id and secret used to sign requests.
    auth: &'a S3AuthConfig,
}
139
impl SignatureProvider for S3SignatureProvider<'_> {
    /// The SigV4 signing algorithm identifier.
    fn algorithm(&self) -> &str {
        "AWS4-HMAC-SHA256"
    }

    /// Prefix prepended to the secret key when deriving the signing key.
    fn secret_key_prefix(&self) -> &str {
        "AWS4"
    }

    /// Terminator of the SigV4 credential scope.
    fn request_type(&self) -> &str {
        "aws4_request"
    }

    /// The AWS region used in the credential scope.
    fn region(&self) -> &str {
        self.region
    }

    /// The AWS service name used in the credential scope.
    fn service(&self) -> &str {
        "s3"
    }

    /// Name of the header carrying the request timestamp.
    fn date_header_name(&self) -> &str {
        AWS_DATE_HEADER
    }

    /// Name of the header carrying the payload SHA-256 digest.
    fn content_hash_header_name(&self) -> &str {
        AWS_CONTENT_SHA256_HEADER
    }

    /// The configured access key identifier.
    fn access_key_id(&self) -> &str {
        &self.auth.access_key_id
    }

    /// The configured secret access key, exposed only for signing.
    fn secret_access_key(&self) -> &str {
        self.auth.secret_access_key.expose_secret()
    }
}
177
178fn append_authentication_header(
180 auth: &S3AuthConfig,
181 date: DateTime<Utc>,
182 request: &mut Request,
183) -> Result<()> {
184 let signer = RequestSigner::new(S3SignatureProvider {
185 region: request.url().region(),
186 auth,
187 });
188 let auth = signer
189 .sign(date, request)
190 .ok_or(S3Error::InvalidSecretAccessKey)?;
191 request.headers_mut().append(
192 header::AUTHORIZATION,
193 HeaderValue::try_from(auth).expect("value should be valid"),
194 );
195 Ok(())
196}
197
/// Extensions to [`Url`] for extracting S3-specific components.
trait UrlExt {
    /// Gets the AWS region from the URL's domain.
    ///
    /// # Panics
    ///
    /// Panics if the URL does not have a domain in a supported S3 format.
    fn region(&self) -> &str;

    /// Splits the URL into its bucket name and object path.
    ///
    /// # Panics
    ///
    /// Panics if the URL does not have a domain or path in a supported S3
    /// format.
    fn bucket_and_path(&self) -> (&str, &str);
}
214
215impl UrlExt for Url {
216 fn region(&self) -> &str {
217 let domain = self.domain().expect("URL should have domain");
218
219 if domain.starts_with("s3.") || domain.starts_with("S3.") {
220 let mut parts = domain.splitn(3, '.');
222 match (parts.next(), parts.next()) {
223 (_, Some(region)) => region,
224 _ => panic!("invalid S3 URL"),
225 }
226 } else {
227 let mut parts = domain.splitn(4, '.');
229
230 match (parts.next(), parts.next(), parts.next()) {
231 (_, _, Some(region)) => region,
232 _ => panic!("invalid S3 URL"),
233 }
234 }
235 }
236
237 fn bucket_and_path(&self) -> (&str, &str) {
238 let domain = self.domain().expect("URL should have domain");
239
240 if domain.starts_with("s3.") || domain.starts_with("S3.") {
241 let bucket = self
243 .path_segments()
244 .expect("URL should have path")
245 .next()
246 .expect("URL should have at least one path segment");
247
248 (
249 bucket,
250 self.path()
251 .strip_prefix('/')
252 .unwrap()
253 .strip_prefix(bucket)
254 .unwrap(),
255 )
256 } else {
257 let Some((bucket, _)) = domain.split_once('.') else {
259 panic!("URL domain does not contain a bucket");
260 };
261
262 (bucket, self.path())
263 }
264 }
265}
266
/// Extensions to [`Response`] for converting error responses.
trait ResponseExt {
    /// Consumes the response and converts it into an [`Error`].
    async fn into_error(self) -> Error;
}
272
273impl ResponseExt for Response {
274 async fn into_error(self) -> Error {
275 #[derive(Default, Deserialize)]
277 #[serde(rename = "Error")]
278 struct ErrorResponse {
279 #[serde(rename = "Message")]
281 message: String,
282 }
283
284 let status = self.status();
285
286 if status == StatusCode::MOVED_PERMANENTLY {
288 return Error::Server {
289 status,
290 message: "the AWS region being used may not be the correct region for the storage \
291 bucket"
292 .into(),
293 };
294 }
295
296 let text: String = match self.text().await {
297 Ok(text) => text,
298 Err(e) => return e.into(),
299 };
300
301 if text.is_empty() {
302 return Error::Server {
303 status,
304 message: text,
305 };
306 }
307
308 let message = match serde_xml_rs::from_str::<ErrorResponse>(&text) {
309 Ok(response) => response.message,
310 Err(e) => {
311 return S3Error::UnexpectedResponse { status, error: e }.into();
312 }
313 };
314
315 Error::Server { status, message }
316 }
317}
318
/// A completed part of a multipart upload, echoed back to S3 when the upload
/// is finalized.
#[derive(Default, Clone, Serialize)]
#[serde(rename = "Part")]
pub struct S3UploadPart {
    /// The 1-based part number.
    #[serde(rename = "PartNumber")]
    number: u64,
    /// The ETag returned by S3 for the uploaded part.
    #[serde(rename = "ETag")]
    etag: String,
}
330
/// An in-progress S3 multipart upload.
pub struct S3Upload {
    /// Shared application configuration (provides S3 auth settings).
    config: Arc<Config>,
    /// HTTP client used to issue the part and finalize requests.
    client: HttpClient,
    /// The object URL being uploaded to.
    url: Url,
    /// The multipart upload identifier returned by `CreateMultipartUpload`.
    id: String,
    /// Optional channel for broadcasting transfer progress events.
    events: Option<broadcast::Sender<TransferEvent>>,
}
344
impl Upload for S3Upload {
    type Part = S3UploadPart;

    /// Uploads a single block of the file as an S3 multipart upload part.
    ///
    /// `block` is 0-based; S3 part numbers are 1-based, hence `block + 1`.
    /// Returns the part number and ETag needed to finalize the upload.
    async fn put(&self, id: u64, block: u64, bytes: Bytes) -> Result<Self::Part> {
        debug!(
            "sending PUT request for block {block} of `{url}`",
            url = self.url.display()
        );

        let mut url = self.url.clone();

        // Scope the query-pair builder so its mutable borrow of `url` ends
        // before the URL is consumed by the request builder.
        {
            let mut pairs = url.query_pairs_mut();
            pairs.append_pair("partNumber", &format!("{number}", number = block + 1));
            pairs.append_pair("uploadId", &self.id);
        }

        // Hash the payload before it is moved into the streaming body; the
        // digest is required for SigV4 signing.
        let digest = sha256_hex_string(&bytes);
        let length = bytes.len();
        // NOTE(review): the fourth argument appears to be a starting offset
        // for transfer-event reporting — confirm against TransferStream::new.
        let body = Body::wrap_stream(TransferStream::new(
            ByteStream::new(bytes),
            id,
            block,
            0,
            self.events.clone(),
        ));

        let date = Utc::now();
        let mut request = self
            .client
            .put(url)
            .header(header::USER_AGENT, USER_AGENT)
            .header(header::CONTENT_LENGTH, length)
            .header(header::CONTENT_TYPE, "application/octet-stream")
            .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
            .header(AWS_CONTENT_SHA256_HEADER, &digest)
            .body(body)
            .build()?;

        // Signing must happen after all other headers are set, as the
        // signature covers them.
        if let Some(auth) = &self.config.s3.auth {
            append_authentication_header(auth, date, &mut request)?;
        }

        let response = self.client.execute(request).await?;
        if !response.status().is_success() {
            return Err(response.into_error().await);
        }

        // The ETag is mandatory: it must be echoed back in the
        // CompleteMultipartUpload request.
        let etag = response
            .headers()
            .get(header::ETAG)
            .and_then(|v| v.to_str().ok())
            .ok_or(S3Error::ResponseMissingETag)?;

        Ok(S3UploadPart {
            number: block + 1,
            etag: etag.to_string(),
        })
    }

    /// Completes the multipart upload by POSTing the list of uploaded parts
    /// as a `CompleteMultipartUpload` XML document.
    async fn finalize(&self, parts: &[Self::Part]) -> Result<()> {
        /// Serialized form of the CompleteMultipartUpload request body.
        #[derive(Serialize)]
        #[serde(rename = "CompleteMultipartUpload")]
        struct CompleteUpload<'a> {
            #[serde(rename = "Part")]
            parts: &'a [S3UploadPart],
        }

        debug!(
            "sending POST request to finalize upload of `{url}`",
            url = self.url.display()
        );

        let mut url = self.url.clone();

        {
            let mut pairs = url.query_pairs_mut();
            pairs.append_pair("uploadId", &self.id);
        }

        let body = serde_xml_rs::SerdeXml::new()
            .default_namespace("http://s3.amazonaws.com/doc/2006-03-01/")
            .to_string(&CompleteUpload { parts })
            .expect("should serialize");

        let date = Utc::now();
        let mut request = self
            .client
            .post(url)
            .header(header::USER_AGENT, USER_AGENT)
            .header(header::CONTENT_LENGTH, body.len())
            .header(header::CONTENT_TYPE, "application/xml")
            .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
            .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string(&body))
            .body(body)
            .build()?;

        if let Some(auth) = &self.config.s3.auth {
            append_authentication_header(auth, date, &mut request)?;
        }

        let response = self.client.execute(request).await?;
        if !response.status().is_success() {
            return Err(response.into_error().await);
        }

        Ok(())
    }
}
460
/// Storage backend for Amazon S3 (and the LocalStack emulator).
pub struct S3StorageBackend {
    /// Shared application configuration.
    config: Arc<Config>,
    /// HTTP client used for all S3 requests.
    client: HttpClient,
    /// Optional channel for broadcasting transfer progress events.
    events: Option<broadcast::Sender<TransferEvent>>,
}
470
471impl S3StorageBackend {
472 pub fn new(
474 config: Config,
475 client: HttpClient,
476 events: Option<broadcast::Sender<TransferEvent>>,
477 ) -> Self {
478 Self {
479 config: Arc::new(config),
480 client,
481 events,
482 }
483 }
484}
485
486impl StorageBackend for S3StorageBackend {
487 type Upload = S3Upload;
488
489 fn config(&self) -> &Config {
490 &self.config
491 }
492
493 fn cache(&self) -> Option<&Cache<DefaultCacheStorage>> {
494 self.client.cache()
495 }
496
497 fn events(&self) -> &Option<broadcast::Sender<TransferEvent>> {
498 &self.events
499 }
500
501 fn block_size(&self, file_size: u64) -> Result<u64> {
502 const BLOCK_COUNT_INCREMENT: u64 = 50;
504
505 if let Some(size) = self.config.block_size {
507 if size > MAX_PART_SIZE {
508 return Err(S3Error::InvalidBlockSize.into());
509 }
510
511 return Ok(size);
512 }
513
514 let mut num_blocks: u64 = BLOCK_COUNT_INCREMENT;
516 while num_blocks < MAX_PARTS {
517 let block_size = file_size.div_ceil(num_blocks).next_power_of_two();
518 if block_size <= BLOCK_SIZE_THRESHOLD {
519 return Ok(block_size.max(MIN_PART_SIZE));
520 }
521
522 num_blocks += BLOCK_COUNT_INCREMENT;
523 }
524
525 let block_size: u64 = file_size.div_ceil(MAX_PARTS);
528 if block_size > MAX_PART_SIZE {
529 return Err(S3Error::MaximumSizeExceeded.into());
530 }
531
532 Ok(block_size)
533 }
534
535 fn is_supported_url(config: &Config, url: &Url) -> bool {
536 match url.scheme() {
537 "s3" => true,
538 "http" | "https" => {
539 let Some(domain) = url.domain() else {
540 return false;
541 };
542
543 if domain.starts_with("s3.") || domain.starts_with("S3.") {
544 let domain = &domain[3..];
546 let Some((region, domain)) = domain.split_once('.') else {
547 return false;
548 };
549
550 !region.is_empty()
552 && (domain.eq_ignore_ascii_case(AWS_ROOT_DOMAIN)
553 || (config.s3.use_localstack
554 && domain.eq_ignore_ascii_case(LOCALSTACK_ROOT_DOMAIN)))
555 && url
556 .path_segments()
557 .map(|mut s| s.nth(1).is_some())
558 .unwrap_or(false)
559 } else {
560 let mut parts = domain.splitn(4, '.');
562 match (parts.next(), parts.next(), parts.next(), parts.next()) {
563 (Some(bucket), Some(service), Some(region), Some(domain)) => {
564 !bucket.is_empty()
566 && !region.is_empty()
567 && service.eq_ignore_ascii_case("s3")
568 && (domain.eq_ignore_ascii_case(AWS_ROOT_DOMAIN)
569 || (config.s3.use_localstack
570 && domain.eq_ignore_ascii_case(LOCALSTACK_ROOT_DOMAIN)))
571 && url
572 .path_segments()
573 .map(|mut s| s.next().is_some())
574 .unwrap_or(false)
575 }
576 _ => false,
577 }
578 }
579 }
580 _ => false,
581 }
582 }
583
584 fn rewrite_url(&self, url: Url) -> Result<Url> {
585 match url.scheme() {
586 "s3" => {
587 let region = self.config.s3.region.as_deref().unwrap_or(DEFAULT_REGION);
588 let bucket = url.host_str().ok_or(S3Error::InvalidScheme)?;
589 let path = url.path();
590
591 if url.path() == "/" {
592 return Err(S3Error::InvalidScheme.into());
593 }
594
595 let (scheme, root, port) = if self.config.azure.use_azurite {
596 ("http", LOCALSTACK_ROOT_DOMAIN, ":4566")
597 } else {
598 ("https", AWS_ROOT_DOMAIN, "")
599 };
600
601 match (url.query(), url.fragment()) {
602 (None, None) => format!("{scheme}://{bucket}.s3.{region}.{root}{port}{path}"),
603 (None, Some(fragment)) => {
604 format!("{scheme}://{bucket}.s3.{region}.{root}{port}{path}#{fragment}")
605 }
606 (Some(query), None) => {
607 format!("{scheme}://{bucket}.s3.{region}.{root}{port}{path}?{query}")
608 }
609 (Some(query), Some(fragment)) => {
610 format!(
611 "{scheme}://{bucket}.s3.{region}.{root}{port}{path}?{query}#{fragment}"
612 )
613 }
614 }
615 .parse()
616 .map_err(|_| S3Error::InvalidScheme.into())
617 }
618 _ => Ok(url),
619 }
620 }
621
622 fn join_url<'a>(&self, mut url: Url, segments: impl Iterator<Item = &'a str>) -> Result<Url> {
623 {
625 let mut existing = url.path_segments_mut().expect("url should have path");
626 existing.pop_if_empty();
627 existing.extend(segments);
628 }
629
630 Ok(url)
631 }
632
633 async fn head(&self, url: Url) -> Result<Response> {
634 debug_assert!(
635 Self::is_supported_url(&self.config, &url),
636 "{url} is not a supported S3 URL",
637 url = url.as_str()
638 );
639
640 debug!("sending HEAD request for `{url}`", url = url.display());
641
642 let date = Utc::now();
643 let mut request = self
644 .client
645 .head(url)
646 .header(header::USER_AGENT, USER_AGENT)
647 .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
648 .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
649 .build()?;
650
651 if let Some(auth) = &self.config.s3.auth {
652 append_authentication_header(auth, date, &mut request)?;
653 }
654
655 let response = self.client.execute(request).await?;
656 if !response.status().is_success() {
657 return Err(response.into_error().await);
658 }
659
660 Ok(response)
661 }
662
663 async fn get(&self, url: Url) -> Result<Response> {
664 debug_assert!(
665 Self::is_supported_url(&self.config, &url),
666 "{url} is not a supported S3 URL",
667 url = url.as_str()
668 );
669
670 debug!("sending GET request for `{url}`", url = url.display());
671
672 let date = Utc::now();
673 let mut request = self
674 .client
675 .get(url)
676 .header(header::USER_AGENT, USER_AGENT)
677 .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
678 .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
679 .build()?;
680
681 if let Some(auth) = &self.config.s3.auth {
682 append_authentication_header(auth, date, &mut request)?;
683 }
684
685 let response = self.client.execute(request).await?;
686 if !response.status().is_success() {
687 return Err(response.into_error().await);
688 }
689
690 Ok(response)
691 }
692
693 async fn get_at_offset(&self, url: Url, etag: &str, offset: u64) -> Result<Response> {
694 debug_assert!(
695 Self::is_supported_url(&self.config, &url),
696 "{url} is not a supported S3 URL",
697 url = url.as_str()
698 );
699
700 debug!(
701 "sending GET request at offset {offset} for `{url}`",
702 url = url.display(),
703 );
704
705 let date = Utc::now();
706
707 let mut request = self
708 .client
709 .get(url)
710 .header(header::USER_AGENT, USER_AGENT)
711 .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
712 .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
713 .header(header::RANGE, format!("bytes={offset}-"))
714 .header(header::IF_MATCH, etag)
715 .build()?;
716
717 if let Some(auth) = &self.config.s3.auth {
718 append_authentication_header(auth, date, &mut request)?;
719 }
720
721 let response = self.client.execute(request).await?;
722 let status = response.status();
723
724 if status == StatusCode::PRECONDITION_FAILED {
726 return Err(Error::RemoteContentModified);
727 }
728
729 if !status.is_success() {
731 return Err(response.into_error().await);
732 }
733
734 if status != StatusCode::PARTIAL_CONTENT {
736 return Err(Error::RemoteContentModified);
737 }
738
739 Ok(response)
740 }
741
742 async fn walk(&self, mut url: Url) -> Result<Vec<String>> {
743 debug_assert!(
746 Self::is_supported_url(&self.config, &url),
747 "{url} is not a supported S3 URL",
748 url = url.as_str()
749 );
750
751 debug!("walking `{url}` as a directory", url = url.display());
752
753 let (bucket, path) = url.bucket_and_path();
754
755 let mut prefix = path.strip_prefix('/').unwrap_or(path).to_string();
757 prefix.push('/');
758
759 let domain = url.domain().expect("URL should have domain");
761 if domain.starts_with("s3") || domain.starts_with("S3") {
762 url.set_host(Some(&format!("{bucket}.{domain}")))
764 .map_err(|_| S3Error::InvalidBucketName)?;
765 }
766
767 url.set_path("/");
768
769 {
770 let mut pairs = url.query_pairs_mut();
771 pairs.append_pair("list-type", "2");
773 pairs.append_pair("prefix", &prefix);
775 }
776
777 let date = Utc::now();
778 let mut token = String::new();
779 let mut paths = Vec::new();
780 loop {
781 let mut url = url.clone();
782 if !token.is_empty() {
783 url.query_pairs_mut()
784 .append_pair("continuation-token", &token);
785 }
786
787 let mut request = self
789 .client
790 .get(url)
791 .header(header::USER_AGENT, USER_AGENT)
792 .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
793 .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
794 .build()?;
795
796 if let Some(auth) = &self.config.s3.auth {
797 append_authentication_header(auth, date, &mut request)?;
798 }
799
800 let response = self.client.execute(request).await?;
801
802 let status = response.status();
803 if !status.is_success() {
804 return Err(response.into_error().await);
805 }
806
807 let text = response.text().await?;
808 let results: ListBucketResult = match serde_xml_rs::from_str(&text) {
809 Ok(response) => response,
810 Err(e) => {
811 return Err(S3Error::UnexpectedResponse { status, error: e }.into());
812 }
813 };
814
815 if paths.is_empty()
818 && results.contents.len() == 1
819 && results.token.is_none()
820 && let Some("") = results.contents[0].key.strip_prefix(&prefix)
821 {
822 return Ok(paths);
823 }
824
825 paths.extend(
826 results
827 .contents
828 .into_iter()
829 .map(|c| c.key.strip_prefix(&prefix).map(Into::into).unwrap_or(c.key)),
830 );
831
832 token = results.token.unwrap_or_default();
833 if token.is_empty() {
834 break;
835 }
836 }
837
838 Ok(paths)
839 }
840
841 async fn new_upload(&self, url: Url) -> Result<Self::Upload> {
842 debug_assert!(
845 Self::is_supported_url(&self.config, &url),
846 "{url} is not a supported S3 URL",
847 url = url.as_str()
848 );
849
850 debug!("sending POST request for `{url}`", url = url.display());
851
852 let mut create = url.clone();
853 create.query_pairs_mut().append_key_only("uploads");
854
855 let date = Utc::now();
856 let mut request = self
857 .client
858 .post(create)
859 .header(header::USER_AGENT, USER_AGENT)
860 .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
861 .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
862 .build()?;
863
864 if let Some(auth) = &self.config.s3.auth {
865 append_authentication_header(auth, date, &mut request)?;
866 }
867
868 let response = self.client.execute(request).await?;
869
870 let status = response.status();
871 if !status.is_success() {
872 return Err(response.into_error().await);
873 }
874
875 let text: String = match response.text().await {
876 Ok(text) => text,
877 Err(e) => return Err(e.into()),
878 };
879
880 let id = match serde_xml_rs::from_str::<InitiateMultipartUploadResult>(&text) {
881 Ok(response) => response.upload_id,
882 Err(e) => {
883 return Err(S3Error::UnexpectedResponse { status, error: e }.into());
884 }
885 };
886
887 Ok(S3Upload {
888 config: self.config.clone(),
889 client: self.client.clone(),
890 url,
891 id,
892 events: self.events.clone(),
893 })
894 }
895}