use std::borrow::Cow;
use bytes::Bytes;
use chrono::DateTime;
use chrono::Utc;
use http_cache_stream_reqwest::Cache;
use http_cache_stream_reqwest::storage::DefaultCacheStorage;
use reqwest::Body;
use reqwest::Request;
use reqwest::Response;
use reqwest::StatusCode;
use reqwest::header;
use reqwest::header::HeaderValue;
use secrecy::ExposeSecret;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::broadcast;
use tracing::debug;
use url::Url;
use crate::BLOCK_SIZE_THRESHOLD;
use crate::Config;
use crate::Error;
use crate::HttpClient;
use crate::ONE_MEBIBYTE;
use crate::Result;
use crate::S3AuthConfig;
use crate::TransferEvent;
use crate::USER_AGENT;
use crate::UrlExt as _;
use crate::backend::StorageBackend;
use crate::backend::Upload;
use crate::backend::auth::s3::RequestSigner;
use crate::backend::auth::s3::SignatureProvider;
use crate::backend::format_range_header;
use crate::sha256_hex_string;
use crate::streams::ByteStream;
use crate::streams::TransferStream;
/// Root domain for AWS-hosted S3 endpoints.
const AWS_ROOT_DOMAIN: &str = "amazonaws.com";
/// Root domain used by LocalStack's S3 emulation.
const LOCALSTACK_ROOT_DOMAIN: &str = "localhost.localstack.cloud";
/// Maximum number of parts in an S3 multipart upload (AWS-imposed limit).
const MAX_PARTS: u64 = 10000;
/// Minimum size of a multipart upload part (5 MiB).
const MIN_PART_SIZE: u64 = 5 * ONE_MEBIBYTE;
/// Maximum size of a single multipart upload part (5 GiB).
const MAX_PART_SIZE: u64 = MIN_PART_SIZE * 1024;
/// Maximum supported source file size (5 TiB).
const MAX_FILE_SIZE: u64 = MAX_PART_SIZE * 1024;
/// Header carrying the request signing date for SigV4.
const AWS_DATE_HEADER: &str = "x-amz-date";
/// Header carrying the SHA-256 hex digest of the request payload, required by SigV4.
const AWS_CONTENT_SHA256_HEADER: &str = "x-amz-content-sha256";
/// User-defined metadata header used to record the uploaded content's digest.
pub(crate) const AWS_CONTENT_DIGEST_HEADER: &str = "x-amz-meta-content-digest";
/// Errors specific to the S3 storage backend.
#[derive(Debug, thiserror::Error)]
pub enum S3Error {
    /// A configured block size exceeded S3's maximum part size.
    #[error("S3 block size cannot exceed {MAX_PART_SIZE} bytes")]
    InvalidBlockSize,
    /// The source file cannot be uploaded within S3's part count and size limits.
    #[error("the size of the source file exceeds the supported maximum of {MAX_FILE_SIZE} bytes")]
    MaximumSizeExceeded,
    /// An `s3://` scheme URL was malformed (e.g. missing host or object path).
    #[error("invalid URL with `s3` scheme: the URL is not in a supported format")]
    InvalidScheme,
    /// The URL path did not contain a bucket name.
    #[error("URL is missing the bucket in the path")]
    MissingBucket,
    /// The configured secret access key could not be used to sign a request.
    #[error("invalid S3 secret access key")]
    InvalidSecretAccessKey,
    /// A successful upload response lacked the expected `ETag` header.
    #[error("response from server was missing an ETag header")]
    ResponseMissingETag,
    /// The bucket name could not be used as a URL host.
    #[error("the bucket name specified in the URL is invalid")]
    InvalidBucketName,
    /// The server's XML response body could not be deserialized.
    #[error("unexpected {status} response from server: failed to deserialize response contents: {error}", status = .status.as_u16())]
    UnexpectedResponse {
        /// The HTTP status of the response.
        status: reqwest::StatusCode,
        /// The underlying XML deserialization error.
        error: serde_xml_rs::Error,
    },
}
/// A single object entry in an S3 `ListObjectsV2` response.
#[derive(Debug, Deserialize)]
pub struct Content {
    /// The object's full key within the bucket.
    #[serde(rename = "Key")]
    pub key: String,
}
/// Deserialized body of an S3 `ListObjectsV2` response.
#[derive(Debug, Deserialize)]
#[serde(rename = "ListBucketResult")]
pub struct ListBucketResult {
    /// The objects returned in this page of results.
    #[serde(default, rename = "Contents")]
    pub contents: Vec<Content>,
    /// Continuation token for fetching the next page, when the listing is truncated.
    #[serde(rename = "NextContinuationToken", default)]
    pub token: Option<String>,
}
/// Deserialized body of an S3 `CreateMultipartUpload` response.
#[derive(Default, Deserialize)]
#[serde(rename = "InitiateMultipartUploadResult")]
pub struct InitiateMultipartUploadResult {
    /// Identifier of the newly created multipart upload.
    #[serde(rename = "UploadId")]
    pub upload_id: String,
}
/// Supplies the AWS SigV4 signing parameters for S3 requests.
pub struct S3SignatureProvider<'a> {
    /// The AWS region the request targets (derived from the request URL).
    region: &'a str,
    /// The credentials used to sign the request.
    auth: &'a S3AuthConfig,
}
impl SignatureProvider for S3SignatureProvider<'_> {
    /// The SigV4 signing algorithm identifier.
    fn algorithm(&self) -> &str {
        "AWS4-HMAC-SHA256"
    }

    /// Prefix prepended to the secret key when deriving the signing key.
    fn secret_key_prefix(&self) -> &str {
        "AWS4"
    }

    /// Terminal string of the SigV4 credential scope.
    fn request_type(&self) -> &str {
        "aws4_request"
    }

    /// The region component of the credential scope.
    fn region(&self) -> &str {
        self.region
    }

    /// The service component of the credential scope.
    fn service(&self) -> &str {
        "s3"
    }

    /// Name of the header carrying the request date.
    fn date_header_name(&self) -> &str {
        AWS_DATE_HEADER
    }

    /// Name of the header carrying the payload's SHA-256 digest.
    fn content_hash_header_name(&self) -> &str {
        AWS_CONTENT_SHA256_HEADER
    }

    /// The configured AWS access key ID.
    fn access_key_id(&self) -> &str {
        self.auth.access_key_id()
    }

    /// The configured AWS secret access key (exposed only for signing).
    fn secret_access_key(&self) -> &str {
        self.auth.secret_access_key().expose_secret()
    }
}
/// Signs `request` with AWS Signature Version 4 for `date` and attaches the
/// resulting `Authorization` header.
///
/// Returns [`S3Error::InvalidSecretAccessKey`] when the configured secret
/// access key cannot produce a signature.
fn insert_authentication_header(
    auth: &S3AuthConfig,
    date: DateTime<Utc>,
    request: &mut Request,
) -> Result<()> {
    let provider = S3SignatureProvider {
        region: request.url().region(),
        auth,
    };
    let signature = RequestSigner::new(provider)
        .sign(date, request)
        .ok_or(S3Error::InvalidSecretAccessKey)?;
    let value = HeaderValue::try_from(signature).expect("value should be valid");
    request.headers_mut().insert(header::AUTHORIZATION, value);
    Ok(())
}
/// Internal helpers for extracting S3 addressing components from a URL.
trait UrlExt {
    /// Returns the AWS region embedded in the URL's domain.
    fn region(&self) -> &str;
    /// Returns the bucket name and the object path within the bucket.
    fn bucket_and_path(&self) -> (&str, &str);
}
impl UrlExt for Url {
    /// Extracts the region from either a path-style domain
    /// (`s3.<region>.<root>`) or a virtual-hosted-style domain
    /// (`<bucket>.s3.<region>.<root>`).
    ///
    /// Panics when the URL has no domain or the domain lacks a region
    /// component; callers are expected to have validated the URL first.
    fn region(&self) -> &str {
        let domain = self.domain().expect("URL should have domain");
        // Path-style domains put the region second; virtual-hosted-style
        // domains put it third (after the bucket and "s3" labels).
        let (limit, index) = if domain.starts_with("s3.") || domain.starts_with("S3.") {
            (3, 1)
        } else {
            (4, 2)
        };
        match domain.splitn(limit, '.').nth(index) {
            Some(region) => region,
            None => panic!("invalid S3 URL"),
        }
    }

    /// Splits the URL into its bucket name and in-bucket object path.
    ///
    /// For a path-style URL the bucket is the first path segment; for a
    /// virtual-hosted-style URL it is the leading domain label.
    ///
    /// Panics when the URL lacks the expected structure; callers are
    /// expected to have validated the URL first.
    fn bucket_and_path(&self) -> (&str, &str) {
        let domain = self.domain().expect("URL should have domain");
        if domain.starts_with("s3.") || domain.starts_with("S3.") {
            let mut segments = self.path_segments().expect("URL should have path");
            let bucket = segments
                .next()
                .expect("URL should have at least one path segment");
            // Drop the leading slash and the bucket segment, leaving the
            // object path (which retains its own leading slash).
            let remainder = self
                .path()
                .strip_prefix('/')
                .unwrap()
                .strip_prefix(bucket)
                .unwrap();
            (bucket, remainder)
        } else {
            match domain.split_once('.') {
                Some((bucket, _)) => (bucket, self.path()),
                None => panic!("URL domain does not contain a bucket"),
            }
        }
    }
}
/// Helper for converting an unsuccessful HTTP response into an [`Error`].
trait ResponseExt {
    /// Consumes the response and produces the most specific error available.
    async fn into_error(self) -> Error;
}
impl ResponseExt for Response {
async fn into_error(self) -> Error {
#[derive(Default, Deserialize)]
#[serde(rename = "Error")]
struct ErrorResponse {
#[serde(rename = "Message")]
message: String,
}
let status = self.status();
if status == StatusCode::MOVED_PERMANENTLY {
return Error::Server {
status,
message: "the AWS region being used may not be the correct region for the storage \
bucket"
.into(),
};
}
let text: String = match self.text().await {
Ok(text) => text,
Err(e) => return e.into(),
};
if text.is_empty() {
return Error::Server {
status,
message: text,
};
}
let message = match serde_xml_rs::from_str::<ErrorResponse>(&text) {
Ok(response) => response.message,
Err(e) => {
return S3Error::UnexpectedResponse { status, error: e }.into();
}
};
Error::Server { status, message }
}
}
/// A completed part of a multipart upload, echoed back to S3 when the upload
/// is finalized.
#[derive(Default, Clone, Serialize)]
#[serde(rename = "Part")]
pub struct S3UploadPart {
    /// The 1-based part number.
    #[serde(rename = "PartNumber")]
    number: u64,
    /// The ETag returned by S3 for this part.
    #[serde(rename = "ETag")]
    etag: String,
}
/// An in-progress S3 multipart upload.
pub struct S3Upload {
    /// Shared configuration (credentials, options).
    config: Config,
    /// HTTP client used for part and finalize requests.
    client: HttpClient,
    /// The destination object URL.
    url: Url,
    /// The multipart upload ID returned by `CreateMultipartUpload`.
    id: String,
    /// Optional channel for transfer progress events.
    events: Option<broadcast::Sender<TransferEvent>>,
}
impl Upload for S3Upload {
    type Part = S3UploadPart;

    /// Uploads a single block as an S3 `UploadPart` request.
    ///
    /// `block` is the 0-based block index; S3 part numbers are 1-based, so
    /// `block + 1` is sent as the part number and recorded in the returned
    /// part. `id` is forwarded to the transfer stream for progress events.
    ///
    /// Returns the part record required later by [`finalize`](Self::finalize),
    /// or an error if the request fails or the response lacks an `ETag`.
    async fn put(&self, id: u64, block: u64, bytes: Bytes) -> Result<Option<Self::Part>> {
        debug!(
            "sending PUT request for block {block} of `{url}`",
            url = self.url.display()
        );
        let mut url = self.url.clone();
        {
            // Scope the mutable borrow of the query pairs.
            let mut pairs = url.query_pairs_mut();
            pairs.append_pair("partNumber", &format!("{number}", number = block + 1));
            pairs.append_pair("uploadId", &self.id);
        }
        // SigV4 needs the payload's SHA-256; compute it before `bytes` is
        // moved into the streaming body.
        let digest = sha256_hex_string(&bytes);
        let length = bytes.len();
        let body = Body::wrap_stream(TransferStream::new(
            ByteStream::new(bytes),
            id,
            block,
            0,
            self.events.clone(),
        ));
        let date = Utc::now();
        let mut request = self
            .client
            .put(url)
            .header(header::USER_AGENT, USER_AGENT)
            .header(header::CONTENT_LENGTH, length)
            .header(header::CONTENT_TYPE, "application/octet-stream")
            .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
            .header(AWS_CONTENT_SHA256_HEADER, &digest)
            .body(body)
            .build()?;
        // Only sign when credentials are configured; unsigned requests are
        // used for anonymous access.
        if let Some(auth) = self.config.s3().auth() {
            insert_authentication_header(auth, date, &mut request)?;
        }
        let response = self.client.execute(request).await?;
        if !response.status().is_success() {
            return Err(response.into_error().await);
        }
        // S3 identifies each uploaded part by its ETag; it is required when
        // completing the multipart upload.
        let etag = response
            .headers()
            .get(header::ETAG)
            .and_then(|v| v.to_str().ok())
            .ok_or(S3Error::ResponseMissingETag)?;
        Ok(Some(S3UploadPart {
            number: block + 1,
            etag: etag.to_string(),
        }))
    }

    /// Completes the multipart upload by POSTing a
    /// `CompleteMultipartUpload` XML document listing every uploaded part.
    async fn finalize(&self, parts: &[Self::Part]) -> Result<()> {
        /// XML request body listing the uploaded parts.
        #[derive(Serialize)]
        #[serde(rename = "CompleteMultipartUpload")]
        struct CompleteUpload<'a> {
            #[serde(rename = "Part")]
            parts: &'a [S3UploadPart],
        }
        debug!(
            "sending POST request to finalize upload of `{url}`",
            url = self.url.display()
        );
        let mut url = self.url.clone();
        {
            let mut pairs = url.query_pairs_mut();
            pairs.append_pair("uploadId", &self.id);
        }
        let body = serde_xml_rs::SerdeXml::new()
            .default_namespace("http://s3.amazonaws.com/doc/2006-03-01/")
            .to_string(&CompleteUpload { parts })
            .expect("should serialize");
        let date = Utc::now();
        let mut request = self
            .client
            .post(url)
            .header(header::USER_AGENT, USER_AGENT)
            .header(header::CONTENT_LENGTH, body.len())
            .header(header::CONTENT_TYPE, "application/xml")
            .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
            .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string(&body))
            .body(body)
            .build()?;
        if let Some(auth) = self.config.s3().auth() {
            insert_authentication_header(auth, date, &mut request)?;
        }
        let response = self.client.execute(request).await?;
        if !response.status().is_success() {
            return Err(response.into_error().await);
        }
        Ok(())
    }
}
/// Storage backend for Amazon S3 (and LocalStack's S3 emulation).
pub struct S3StorageBackend {
    /// Shared configuration.
    config: Config,
    /// HTTP client used for all requests.
    client: HttpClient,
    /// Optional channel for transfer progress events.
    events: Option<broadcast::Sender<TransferEvent>>,
}
impl S3StorageBackend {
    /// Creates a new S3 storage backend.
    ///
    /// `events`, when provided, receives transfer progress notifications.
    pub fn new(
        config: Config,
        client: HttpClient,
        events: Option<broadcast::Sender<TransferEvent>>,
    ) -> Self {
        Self {
            config,
            client,
            events,
        }
    }
}
impl StorageBackend for S3StorageBackend {
    type Upload = S3Upload;

    fn config(&self) -> &Config {
        &self.config
    }

    fn cache(&self) -> Option<&Cache<DefaultCacheStorage>> {
        self.client.cache()
    }

    fn events(&self) -> &Option<broadcast::Sender<TransferEvent>> {
        &self.events
    }

    /// Determines the part ("block") size for uploading a file of
    /// `file_size` bytes.
    ///
    /// A user-configured block size is honored after validating it against
    /// S3's maximum part size. Otherwise, the search increases the allowed
    /// part count in increments of 50 and picks the smallest power-of-two
    /// block size (clamped up to `MIN_PART_SIZE`) that fits under
    /// `BLOCK_SIZE_THRESHOLD`; failing that, the file is divided into the
    /// maximum number of parts.
    ///
    /// # Errors
    ///
    /// Returns [`S3Error::InvalidBlockSize`] for an oversized configured
    /// block, or [`S3Error::MaximumSizeExceeded`] when even `MAX_PARTS`
    /// parts of `MAX_PART_SIZE` cannot hold the file.
    fn block_size(&self, file_size: u64) -> Result<u64> {
        /// Number of additional parts allowed per search iteration.
        const BLOCK_COUNT_INCREMENT: u64 = 50;
        if let Some(size) = self.config.block_size() {
            if size > MAX_PART_SIZE {
                return Err(S3Error::InvalidBlockSize.into());
            }
            // NOTE(review): sizes below `MIN_PART_SIZE` (5 MiB) are not
            // rejected here even though S3 rejects undersized parts other
            // than the last — confirm this is validated elsewhere.
            return Ok(size);
        }
        let mut num_blocks: u64 = BLOCK_COUNT_INCREMENT;
        while num_blocks < MAX_PARTS {
            let block_size = file_size.div_ceil(num_blocks).next_power_of_two();
            if block_size <= BLOCK_SIZE_THRESHOLD {
                return Ok(block_size.max(MIN_PART_SIZE));
            }
            num_blocks += BLOCK_COUNT_INCREMENT;
        }
        // Fall back to the maximum allowed part count; if the parts would
        // still be too large, the file cannot be uploaded at all.
        let block_size: u64 = file_size.div_ceil(MAX_PARTS);
        if block_size > MAX_PART_SIZE {
            return Err(S3Error::MaximumSizeExceeded.into());
        }
        Ok(block_size)
    }

    /// Returns `true` for URLs this backend can service: any `s3://` URL, or
    /// an HTTP(S) URL in path-style (`s3.<region>.<root>/<bucket>/<key>`) or
    /// virtual-hosted-style (`<bucket>.s3.<region>.<root>/<key>`) form, where
    /// `<root>` is AWS or (when enabled) LocalStack.
    fn is_supported_url(config: &Config, url: &Url) -> bool {
        match url.scheme() {
            "s3" => true,
            "http" | "https" => {
                let Some(domain) = url.domain() else {
                    return false;
                };
                if domain.starts_with("s3.") || domain.starts_with("S3.") {
                    // Path-style: require `s3.<region>.<root>` plus at least
                    // a bucket and a key in the path.
                    let domain = &domain[3..];
                    let Some((region, domain)) = domain.split_once('.') else {
                        return false;
                    };
                    !region.is_empty()
                        && (domain.eq_ignore_ascii_case(AWS_ROOT_DOMAIN)
                            || (config.s3().use_localstack()
                                && domain.eq_ignore_ascii_case(LOCALSTACK_ROOT_DOMAIN)))
                        && url
                            .path_segments()
                            .map(|mut s| s.nth(1).is_some())
                            .unwrap_or(false)
                } else {
                    // Virtual-hosted-style: `<bucket>.s3.<region>.<root>`
                    // plus at least a key in the path.
                    let mut parts = domain.splitn(4, '.');
                    match (parts.next(), parts.next(), parts.next(), parts.next()) {
                        (Some(bucket), Some(service), Some(region), Some(domain)) => {
                            !bucket.is_empty()
                                && !region.is_empty()
                                && service.eq_ignore_ascii_case("s3")
                                && (domain.eq_ignore_ascii_case(AWS_ROOT_DOMAIN)
                                    || (config.s3().use_localstack()
                                        && domain.eq_ignore_ascii_case(LOCALSTACK_ROOT_DOMAIN)))
                                && url
                                    .path_segments()
                                    .map(|mut s| s.next().is_some())
                                    .unwrap_or(false)
                        }
                        _ => false,
                    }
                }
            }
            _ => false,
        }
    }

    /// Rewrites an `s3://bucket/key` URL into a virtual-hosted-style HTTP(S)
    /// URL using the configured region (or the LocalStack endpoint when
    /// enabled). Other schemes pass through unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`S3Error::InvalidScheme`] when the `s3` URL has no bucket,
    /// no object path, or the rewritten URL fails to parse.
    fn rewrite_url<'a>(config: &Config, url: &'a Url) -> Result<Cow<'a, Url>> {
        match url.scheme() {
            "s3" => {
                let region = config.s3().region();
                let bucket = url.host_str().ok_or(S3Error::InvalidScheme)?;
                let path = url.path();
                // An `s3` URL must identify an object: an empty or root-only
                // path has no key.
                if path.is_empty() || path == "/" {
                    return Err(S3Error::InvalidScheme.into());
                }
                let (scheme, root, port) = if config.s3().use_localstack() {
                    ("http", LOCALSTACK_ROOT_DOMAIN, ":4566")
                } else {
                    ("https", AWS_ROOT_DOMAIN, "")
                };
                // Preserve any query string and fragment from the original URL.
                match (url.query(), url.fragment()) {
                    (None, None) => format!("{scheme}://{bucket}.s3.{region}.{root}{port}{path}"),
                    (None, Some(fragment)) => {
                        format!("{scheme}://{bucket}.s3.{region}.{root}{port}{path}#{fragment}")
                    }
                    (Some(query), None) => {
                        format!("{scheme}://{bucket}.s3.{region}.{root}{port}{path}?{query}")
                    }
                    (Some(query), Some(fragment)) => {
                        format!(
                            "{scheme}://{bucket}.s3.{region}.{root}{port}{path}?{query}#{fragment}"
                        )
                    }
                }
                .parse()
                .map(Cow::Owned)
                .map_err(|_| S3Error::InvalidScheme.into())
            }
            _ => Ok(Cow::Borrowed(url)),
        }
    }

    /// Appends `segments` to `url`'s path, dropping a trailing empty segment
    /// first so the join does not produce a double slash.
    fn join_url<'a>(&self, mut url: Url, segments: impl Iterator<Item = &'a str>) -> Result<Url> {
        {
            let mut existing = url.path_segments_mut().expect("url should have path");
            existing.pop_if_empty();
            existing.extend(segments);
        }
        Ok(url)
    }

    /// Sends a signed HEAD request for `url`.
    ///
    /// When `must_exist` is `false`, a 404 response is returned to the
    /// caller rather than converted into an error; all other failures are
    /// converted via `into_error`.
    async fn head(&self, url: Url, must_exist: bool) -> Result<Response> {
        debug_assert!(
            Self::is_supported_url(&self.config, &url),
            "{url} is not a supported S3 URL",
            url = url.as_str()
        );
        debug!("sending HEAD request for `{url}`", url = url.display());
        let date = Utc::now();
        let mut request = self
            .client
            .head(url)
            .header(header::USER_AGENT, USER_AGENT)
            .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
            // HEAD has no body, so sign the hash of the empty payload.
            .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
            .build()?;
        if let Some(auth) = self.config.s3().auth() {
            insert_authentication_header(auth, date, &mut request)?;
        }
        let response = self.client.execute(request).await?;
        if !response.status().is_success() {
            if !must_exist && response.status() == StatusCode::NOT_FOUND {
                return Ok(response);
            }
            return Err(response.into_error().await);
        }
        Ok(response)
    }

    /// Sends a signed GET request for `url` and returns the successful
    /// response; failures are converted via `into_error`.
    async fn get(&self, url: Url) -> Result<Response> {
        debug_assert!(
            Self::is_supported_url(&self.config, &url),
            "{url} is not a supported S3 URL",
            url = url.as_str()
        );
        debug!("sending GET request for `{url}`", url = url.display());
        let date = Utc::now();
        let mut request = self
            .client
            .get(url)
            .header(header::USER_AGENT, USER_AGENT)
            .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
            .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
            .build()?;
        if let Some(auth) = self.config.s3().auth() {
            insert_authentication_header(auth, date, &mut request)?;
        }
        let response = self.client.execute(request).await?;
        if !response.status().is_success() {
            return Err(response.into_error().await);
        }
        Ok(response)
    }

    /// Sends a signed ranged GET (`start..exclusive_end`) conditioned on
    /// `etag` via `If-Match`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::RemoteContentModified`] when the ETag no longer
    /// matches (412) or when the server ignores the range and responds with
    /// something other than 206 Partial Content.
    async fn get_range(
        &self,
        url: Url,
        etag: &str,
        start: u64,
        exclusive_end: Option<u64>,
    ) -> Result<Response> {
        debug_assert!(
            Self::is_supported_url(&self.config, &url),
            "{url} is not a supported S3 URL",
            url = url.as_str()
        );
        let range = format_range_header(start, exclusive_end);
        debug!(
            "sending GET request with range `{range}` for `{url}`",
            url = url.display(),
        );
        let date = Utc::now();
        let mut request = self
            .client
            .get(url)
            .header(header::USER_AGENT, USER_AGENT)
            .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
            .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
            .header(header::RANGE, range)
            .header(header::IF_MATCH, etag)
            .build()?;
        if let Some(auth) = self.config.s3().auth() {
            insert_authentication_header(auth, date, &mut request)?;
        }
        let response = self.client.execute(request).await?;
        let status = response.status();
        if status == StatusCode::PRECONDITION_FAILED {
            return Err(Error::RemoteContentModified);
        }
        if !status.is_success() {
            return Err(response.into_error().await);
        }
        // Anything other than 206 means the range was not honored, which we
        // treat as the content having changed underneath us.
        if status != StatusCode::PARTIAL_CONTENT {
            return Err(Error::RemoteContentModified);
        }
        Ok(response)
    }

    /// Lists the keys under `url`, treated as a directory prefix, via
    /// paginated `ListObjectsV2` requests.
    ///
    /// Returned paths are relative to the prefix. When `first_only` is set,
    /// at most one key is requested and pagination stops after the first
    /// page. A listing containing only the prefix itself (a zero-byte
    /// directory marker) yields an empty result.
    async fn walk(&self, mut url: Url, first_only: bool) -> Result<Vec<String>> {
        debug_assert!(
            Self::is_supported_url(&self.config, &url),
            "{url} is not a supported S3 URL",
            url = url.as_str()
        );
        debug!("walking `{url}` as a directory", url = url.display());
        let (bucket, path) = url.bucket_and_path();
        let mut prefix = path.strip_prefix('/').unwrap_or(path).to_string();
        if !prefix.ends_with('/') {
            prefix.push('/');
        }
        let domain = url.domain().expect("URL should have domain");
        // Path-style URLs (`s3.<region>.<root>`) need the bucket moved into
        // the host so the listing targets the bucket endpoint. Match on
        // `s3.` (with the dot), consistent with the other path-style checks,
        // so a virtual-hosted bucket whose name merely starts with "s3" is
        // not mangled.
        if domain.starts_with("s3.") || domain.starts_with("S3.") {
            url.set_host(Some(&format!("{bucket}.{domain}")))
                .map_err(|_| S3Error::InvalidBucketName)?;
        }
        url.set_path("/");
        {
            let mut pairs = url.query_pairs_mut();
            pairs.append_pair("list-type", "2");
            pairs.append_pair("prefix", &prefix);
            if first_only {
                pairs.append_pair("max-keys", "1");
            }
        }
        let mut token = String::new();
        let mut paths = Vec::new();
        loop {
            let mut url = url.clone();
            if !token.is_empty() {
                url.query_pairs_mut()
                    .append_pair("continuation-token", &token);
            }
            // Sign each paginated request with a fresh date so that long
            // listings do not outlive the signature's validity window.
            let date = Utc::now();
            let mut request = self
                .client
                .get(url)
                .header(header::USER_AGENT, USER_AGENT)
                .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
                .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
                .build()?;
            if let Some(auth) = self.config.s3().auth() {
                insert_authentication_header(auth, date, &mut request)?;
            }
            let response = self.client.execute(request).await?;
            let status = response.status();
            if !status.is_success() {
                return Err(response.into_error().await);
            }
            let text = response.text().await?;
            let results: ListBucketResult = match serde_xml_rs::from_str(&text) {
                Ok(response) => response,
                Err(e) => {
                    return Err(S3Error::UnexpectedResponse { status, error: e }.into());
                }
            };
            // A single result equal to the prefix itself is just the
            // directory marker object — report an empty directory.
            if paths.is_empty()
                && results.contents.len() == 1
                && results.token.is_none()
                && let Some("") = results.contents[0].key.strip_prefix(&prefix)
            {
                return Ok(paths);
            }
            paths.extend(results.contents.into_iter().map(|c| {
                let key = c.key.strip_prefix(&prefix).unwrap_or(&c.key);
                key.strip_prefix('/').unwrap_or(key).into()
            }));
            token = results.token.unwrap_or_default();
            if first_only || token.is_empty() {
                break;
            }
        }
        Ok(paths)
    }

    /// Starts a multipart upload for `url` via `CreateMultipartUpload`.
    ///
    /// Unless overwriting is enabled, a HEAD request first verifies the
    /// destination does not already exist. When `digest` is provided, it is
    /// stored as user metadata on the object.
    ///
    /// # Errors
    ///
    /// Returns [`Error::RemoteDestinationExists`] when the object exists and
    /// overwriting is disabled, or an error if the upload cannot be created.
    async fn new_upload(&self, url: Url, digest: Option<String>) -> Result<Self::Upload> {
        debug_assert!(
            Self::is_supported_url(&self.config, &url),
            "{url} is not a supported S3 URL",
            url = url.as_str()
        );
        if !self.config.overwrite() {
            let response = self.head(url.clone(), false).await?;
            if response.status() != StatusCode::NOT_FOUND {
                return Err(Error::RemoteDestinationExists(url));
            }
        }
        debug!("sending POST request for `{url}`", url = url.display());
        let mut create = url.clone();
        create.query_pairs_mut().append_key_only("uploads");
        let date = Utc::now();
        let mut request = self
            .client
            .post(create)
            .header(header::USER_AGENT, USER_AGENT)
            .header(AWS_DATE_HEADER, date.format("%Y%m%dT%H%M%SZ").to_string())
            .header(AWS_CONTENT_SHA256_HEADER, sha256_hex_string([]))
            .build()?;
        if let Some(digest) = digest {
            request.headers_mut().insert(
                AWS_CONTENT_DIGEST_HEADER,
                digest
                    .try_into()
                    .expect("invalid content digest header value"),
            );
        }
        if let Some(auth) = self.config.s3().auth() {
            insert_authentication_header(auth, date, &mut request)?;
        }
        let response = self.client.execute(request).await?;
        let status = response.status();
        if !status.is_success() {
            return Err(response.into_error().await);
        }
        let text: String = match response.text().await {
            Ok(text) => text,
            Err(e) => return Err(e.into()),
        };
        let id = match serde_xml_rs::from_str::<InitiateMultipartUploadResult>(&text) {
            Ok(response) => response.upload_id,
            Err(e) => {
                return Err(S3Error::UnexpectedResponse { status, error: e }.into());
            }
        };
        Ok(S3Upload {
            config: self.config.clone(),
            client: self.client.clone(),
            url,
            id,
            events: self.events.clone(),
        })
    }
}