use std::io::SeekFrom;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use aws_config::timeout::TimeoutConfig;
use aws_config::{BehaviorVersion, Region};
use aws_sdk_s3::error::{ProvideErrorMetadata, SdkError};
use aws_sdk_s3::primitives::{ByteStream, Length};
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart, MetadataDirective};
use aws_smithy_http_client::tls::{Provider as TlsProvider, rustls_provider::CryptoMode};
use aws_smithy_types_convert::date_time::DateTimeExt;
use bytes::Bytes;
use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode};
use tempfile::NamedTempFile;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::sync::{Mutex, Semaphore};
use tokio::task::JoinSet;
use url::Url;
use crate::url::{
AWS_S3_INFIXES, RemoteUrl, S3Addressing, s3_virtual_hosted_bucket, strip_aws_host_suffix,
};
use super::error::{network_boxed, other_boxed};
use super::multipart::{
MULTIPART_PUT_MAX_CONCURRENCY, MULTIPART_PUT_PART_SIZE, S3_MAX_PARTS, UploadPart,
plan_upload_parts, read_file_part, should_use_multipart, slice_bytes_part,
};
use super::{
GetOpts, ObjectMeta, ObjectStore, ObjectStoreError, ProgressSink, PutOpts, persist_temp,
};
/// Object size in bytes (25 MiB) up to which `head_then_download` uses a single GET;
/// larger objects take the parallel ranged-GET path.
pub(crate) const MULTIPART_THRESHOLD: u64 = 25 << 20;
/// Length in bytes (16 MiB) of each range requested by the parallel download path.
pub(crate) const MULTIPART_CHUNK_SIZE: u64 = 16 << 20;
/// Maximum number of ranged GETs allowed in flight at once during a parallel download.
pub(crate) const MULTIPART_MAX_CONCURRENCY: usize = 8;
/// Limit in bytes (5 GiB) reported through `ObjectStoreError::PayloadTooLarge`.
pub(crate) const SINGLE_PUT_LIMIT_BYTES: u64 = 5 << 30;
/// Percent-encoding set that `encode_copy_source` applies to both the bucket and the key.
///
/// Extends `CONTROLS` with the space character and the ASCII punctuation listed below.
/// `/`, `-`, `.`, `_` and `~` are deliberately absent from the additions, so `/` separators
/// inside a key are left as-is.
const COPY_SOURCE_ENCODE: &AsciiSet = &CONTROLS
.add(b' ')
.add(b'!')
.add(b'"')
.add(b'#')
.add(b'$')
.add(b'%')
.add(b'&')
.add(b'\'')
.add(b'(')
.add(b')')
.add(b'*')
.add(b'+')
.add(b',')
.add(b':')
.add(b';')
.add(b'<')
.add(b'=')
.add(b'>')
.add(b'?')
.add(b'@')
.add(b'[')
.add(b'\\')
.add(b']')
.add(b'^')
.add(b'`')
.add(b'{')
.add(b'|')
.add(b'}');
/// Idle timeout handed to the HTTP client's connection pool in `build_s3_config`.
pub(crate) const POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
/// Read timeout installed in the default timeout config by `build_s3_config`.
/// Upload requests replace it per call with `upload_timeout_config`.
pub(crate) const READ_TIMEOUT: Duration = Duration::from_secs(30);
/// Returns the timeout configuration used as a per-request override by the upload paths:
/// a `TimeoutConfig` whose read timeout is disabled.
fn upload_timeout_config() -> TimeoutConfig {
    let builder = TimeoutConfig::builder().disable_read_timeout();
    builder.build()
}
/// `ObjectStore` implementation that talks to one S3 bucket through an SDK client.
#[derive(Debug)]
pub struct S3Store {
/// SDK client through which every request is issued.
client: aws_sdk_s3::Client,
/// Name of the bucket passed to every request.
bucket: String,
}
/// Connection settings derived from a remote URL, consumed by `build_s3_config`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ResolvedS3Config {
/// Endpoint handed to the SDK loader (output of `normalize_endpoint`).
pub(crate) endpoint_url: Url,
/// Region to set on the loader; `None` leaves the loader's own region resolution in place.
pub(crate) region: Option<String>,
/// Whether the client is built with path-style addressing forced on.
pub(crate) force_path_style: bool,
/// Named profile to set on the loader, if any.
pub(crate) profile: Option<String>,
}
impl ResolvedS3Config {
    /// Assembles a config from the pieces of a parsed S3 remote URL.
    ///
    /// * `endpoint` - endpoint URL as parsed; it is normalised with `normalize_endpoint`.
    /// * `addressing` - addressing mode; `PathStyle` turns on `force_path_style`.
    /// * `profile` - optional profile name, copied verbatim.
    /// * `region_flag` - optional explicit region, forwarded to `resolve_region`.
    ///
    /// # Errors
    /// Propagates any error from `normalize_endpoint`.
    pub(crate) fn from_url_parts(
        endpoint: &Url,
        addressing: S3Addressing,
        profile: Option<&str>,
        region_flag: Option<&str>,
    ) -> Result<Self, ObjectStoreError> {
        let endpoint_url = normalize_endpoint(endpoint, addressing)?;
        let region = resolve_region(endpoint, region_flag);
        let force_path_style = matches!(addressing, S3Addressing::PathStyle);
        let profile = profile.map(str::to_owned);
        Ok(Self {
            endpoint_url,
            region,
            force_path_style,
            profile,
        })
    }
}
impl S3Store {
    /// Builds a store from a parsed remote URL.
    ///
    /// The URL must be the `RemoteUrl::S3` variant; its endpoint, addressing mode and
    /// flags are resolved into a `ResolvedS3Config`, from which the SDK client is built.
    ///
    /// # Errors
    /// Returns `ObjectStoreError::Other` for a non-S3 URL and propagates any error from
    /// `ResolvedS3Config::from_url_parts`.
    pub async fn from_remote_url(url: &RemoteUrl) -> Result<Self, ObjectStoreError> {
        let (endpoint, bucket, addressing, flags) = match url {
            RemoteUrl::S3 {
                endpoint,
                bucket,
                addressing,
                flags,
                ..
            } => (endpoint, bucket, addressing, flags),
            _ => {
                return Err(ObjectStoreError::Other(
                    format!("S3Store::from_remote_url called with non-S3 URL: {url}").into(),
                ));
            }
        };
        let resolved = ResolvedS3Config::from_url_parts(
            endpoint,
            *addressing,
            flags.profile.as_deref(),
            flags.region.as_deref(),
        )?;
        let client = aws_sdk_s3::Client::from_conf(build_s3_config(&resolved).await);
        Ok(Self {
            client,
            bucket: bucket.clone(),
        })
    }

    /// Issues a one-key `ListObjectsV2` under `prefix` and discards the listing; only
    /// success or the classified failure is reported.
    pub(crate) async fn probe(&self, prefix: &str) -> Result<(), ObjectStoreError> {
        let request = self
            .client
            .list_objects_v2()
            .bucket(&self.bucket)
            .prefix(prefix)
            .max_keys(1);
        match request.send().await {
            Ok(_) => Ok(()),
            Err(err) => Err(classify(err, prefix)),
        }
    }
}
/// Loads the shared AWS configuration and turns it into an S3 client config.
///
/// The loader is given a rustls (aws-lc) HTTPS client with `POOL_IDLE_TIMEOUT`, a read
/// timeout of `READ_TIMEOUT`, and the resolved endpoint URL. The profile and region are
/// applied only when `resolved` carries them. Path-style addressing follows
/// `resolved.force_path_style`.
pub(crate) async fn build_s3_config(resolved: &ResolvedS3Config) -> aws_sdk_s3::Config {
    let http_client = aws_smithy_http_client::Builder::new()
        .tls_provider(TlsProvider::Rustls(CryptoMode::AwsLc))
        .pool_idle_timeout(POOL_IDLE_TIMEOUT)
        .build_https();
    let timeouts = TimeoutConfig::builder().read_timeout(READ_TIMEOUT).build();
    let base = aws_config::defaults(BehaviorVersion::latest())
        .http_client(http_client)
        .timeout_config(timeouts)
        .endpoint_url(resolved.endpoint_url.as_str());
    let with_profile = match &resolved.profile {
        Some(name) => base.profile_name(name),
        None => base,
    };
    let loader = match &resolved.region {
        Some(name) => with_profile.region(Region::new(name.clone())),
        None => with_profile,
    };
    let shared = loader.load().await;
    aws_sdk_s3::config::Builder::from(&shared)
        .force_path_style(resolved.force_path_style)
        .build()
}
/// Reduces a remote URL to the bare service endpoint expected by the SDK.
///
/// The path, query string and fragment are always cleared. For
/// `S3Addressing::VirtualHosted` the leading bucket label is also removed from the host:
/// first using the bucket reported by `s3_virtual_hosted_bucket`, otherwise by dropping
/// everything up to and including the first `.`.
///
/// # Errors
/// Returns `ObjectStoreError::Other` when a virtual-hosted URL has no host, when its host
/// has no `.` separator, or when the rewritten host is rejected by `Url::set_host`.
pub(crate) fn normalize_endpoint(
    endpoint: &Url,
    addressing: S3Addressing,
) -> Result<Url, ObjectStoreError> {
    let mut rewritten = endpoint.clone();
    rewritten.set_path("");
    rewritten.set_query(None);
    rewritten.set_fragment(None);
    if matches!(addressing, S3Addressing::VirtualHosted) {
        let host = rewritten
            .host_str()
            .ok_or_else(|| ObjectStoreError::Other("endpoint URL has no host".into()))?;
        let regional_host = s3_virtual_hosted_bucket(host)
            // Skip `<bucket>.`; a checked slice falls through to the generic split
            // instead of panicking if the reported bucket does not fit the host.
            .and_then(|bucket| host.get(bucket.len() + 1..))
            .or_else(|| host.split_once('.').map(|(_, rest)| rest))
            .map(str::to_owned)
            .ok_or_else(|| {
                ObjectStoreError::Other(
                    format!("virtual-hosted endpoint host `{host}` has no dot separator").into(),
                )
            })?;
        rewritten
            .set_host(Some(&regional_host))
            .map_err(other_boxed)?;
    }
    Ok(rewritten)
}
/// Chooses the region to configure on the SDK loader.
///
/// An explicit `flag` always wins. Otherwise the endpoint host decides:
/// * no host, or exactly `amazonaws.com` -> `None`;
/// * a host that `strip_aws_host_suffix` does not recognise -> `"us-east-1"`;
/// * anything else -> whatever `extract_aws_region` finds in the stripped host.
pub(crate) fn resolve_region(endpoint: &Url, flag: Option<&str>) -> Option<String> {
    flag.map(str::to_owned).or_else(|| {
        let host = endpoint.host_str()?;
        if host == "amazonaws.com" {
            return None;
        }
        match strip_aws_host_suffix(host) {
            Some(trimmed) => extract_aws_region(trimmed),
            None => Some("us-east-1".to_owned()),
        }
    })
}
/// Extracts the region from a host whose AWS suffix has already been stripped.
///
/// Recognised shapes, checked in order:
/// * `s3` -> no region;
/// * `s3.<region>` and `<bucket>.s3.<region>` -> `<region>`;
/// * a single label `s3-<region>` -> `<region>`;
/// * otherwise, the text after the right-most occurrence of any `AWS_S3_INFIXES` entry,
///   accepted only when it is non-empty and contains no `.`.
fn extract_aws_region(trimmed: &str) -> Option<String> {
    let labels: Vec<&str> = trimmed.split('.').collect();
    let region: Option<&str> = match labels.as_slice() {
        ["s3"] => None,
        ["s3", region] | [_, "s3", region] => Some(*region),
        [head] if head.starts_with("s3-") => Some(&head["s3-".len()..]),
        _ => {
            let mut best: Option<(usize, &str)> = None;
            for infix in AWS_S3_INFIXES.iter() {
                if let Some(idx) = trimmed.rfind(infix) {
                    // `>=` lets a later entry win ties, as `Iterator::max_by_key` does.
                    if best.map_or(true, |(prev, _)| idx >= prev) {
                        best = Some((idx, &trimmed[idx + infix.len()..]));
                    }
                }
            }
            best.map(|(_, tail)| tail)
                .filter(|tail| !tail.is_empty() && !tail.contains('.'))
        }
    };
    region.map(str::to_owned)
}
/// Splits `0..size` into consecutive inclusive `(start, end)` byte ranges of at most
/// `chunk_size` bytes each.
///
/// Returns an empty vector when either `size` or `chunk_size` is zero. The final range is
/// shortened so that it ends at `size - 1`.
///
/// The end of each range is computed with saturating arithmetic, so chunk sizes close to
/// `u64::MAX` cannot overflow.
pub(crate) fn plan_ranges(size: u64, chunk_size: u64) -> Vec<(u64, u64)> {
    if size == 0 || chunk_size == 0 {
        return Vec::new();
    }
    let last = size - 1;
    let step = chunk_size - 1;
    let mut ranges = Vec::new();
    let mut start = 0u64;
    loop {
        let end = start.saturating_add(step).min(last);
        ranges.push((start, end));
        if end == last {
            break;
        }
        // `end < last <= u64::MAX - 1`, so this increment cannot overflow.
        start = end + 1;
    }
    ranges
}
/// Formats the copy-source value `<bucket>/<key>`, with both components percent-encoded
/// using `COPY_SOURCE_ENCODE`.
pub(crate) fn encode_copy_source(bucket: &str, key: &str) -> String {
    format!(
        "{}/{}",
        utf8_percent_encode(bucket, COPY_SOURCE_ENCODE),
        utf8_percent_encode(key, COPY_SOURCE_ENCODE),
    )
}
/// Converts an SDK error into an `ObjectStoreError`, using `key` as the context value.
///
/// Service errors are first offered to `classify_status_and_code`. Anything it does not
/// map, and every non-service error, becomes a network error for dispatch failures and
/// timeouts, and an `other` error otherwise.
fn classify<E>(err: SdkError<E>, key: &str) -> ObjectStoreError
where
    E: std::error::Error + Send + Sync + 'static + ProvideErrorMetadata,
{
    let mapped = match &err {
        SdkError::ServiceError(svc) => {
            classify_status_and_code(svc.raw().status().as_u16(), svc.err().code(), key)
        }
        _ => None,
    };
    if let Some(mapped) = mapped {
        return mapped;
    }
    if matches!(
        &err,
        SdkError::DispatchFailure(_) | SdkError::TimeoutError(_)
    ) {
        network_boxed(err)
    } else {
        other_boxed(err)
    }
}
/// Converts one `ListObjectsV2` entry into an `ObjectMeta`.
///
/// A missing or negative size is recorded as `0`, and the entry's ETag is not carried
/// over (`etag` is always `None`).
///
/// # Errors
/// Returns `ObjectStoreError::Other` when the entry has no key or no last-modified
/// timestamp, or when the timestamp cannot be converted.
pub(crate) fn object_to_meta(
    obj: &aws_sdk_s3::types::Object,
) -> Result<ObjectMeta, ObjectStoreError> {
    let Some(key) = obj.key().map(str::to_owned) else {
        return Err(ObjectStoreError::Other(
            "list_objects_v2 returned an object without a key".into(),
        ));
    };
    let size = obj
        .size()
        .and_then(|raw| u64::try_from(raw).ok())
        .unwrap_or(0);
    let Some(stamp) = obj.last_modified() else {
        return Err(ObjectStoreError::Other(
            format!("list_objects_v2 returned object `{key}` without last_modified").into(),
        ));
    };
    let last_modified = stamp.to_time().map_err(other_boxed)?;
    Ok(ObjectMeta {
        key,
        size,
        last_modified,
        etag: None,
    })
}
/// Builds an `ObjectMeta` for `key` from the fields of a `HeadObject` response.
///
/// A negative `content_length` is recorded as size `0`; `etag` is copied when present.
///
/// # Errors
/// Returns `ObjectStoreError::Other` when `content_length` or `last_modified` is absent,
/// or when the timestamp cannot be converted.
pub(crate) fn head_output_to_meta(
    key: &str,
    content_length: Option<i64>,
    last_modified: Option<&aws_sdk_s3::primitives::DateTime>,
    etag: Option<&str>,
) -> Result<ObjectMeta, ObjectStoreError> {
    let Some(raw_size) = content_length else {
        return Err(ObjectStoreError::Other(
            format!("head_object on `{key}` returned no content-length").into(),
        ));
    };
    let Some(stamp) = last_modified else {
        return Err(ObjectStoreError::Other(
            format!("head_object on `{key}` returned no last_modified").into(),
        ));
    };
    Ok(ObjectMeta {
        key: key.to_owned(),
        size: u64::try_from(raw_size).unwrap_or_default(),
        last_modified: stamp.to_time().map_err(other_boxed)?,
        etag: etag.map(str::to_owned),
    })
}
/// Maps an HTTP status and optional service error code onto a typed `ObjectStoreError`.
///
/// The status takes precedence: the error code is consulted only when the status is not
/// one of 404, 403, 412, 409 or 413. Returns `None` when neither is recognised.
fn classify_status_and_code(
    status: u16,
    code: Option<&str>,
    key: &str,
) -> Option<ObjectStoreError> {
    let too_large = || ObjectStoreError::PayloadTooLarge {
        limit_bytes: SINGLE_PUT_LIMIT_BYTES,
    };
    // Arm order matters: every status arm precedes every code arm.
    let mapped = match (status, code) {
        (404, _) => ObjectStoreError::NotFound(key.to_owned()),
        (403, _) => ObjectStoreError::AccessDenied(key.to_owned()),
        (412, _) => ObjectStoreError::PreconditionFailed(key.to_owned()),
        (409, _) => ObjectStoreError::Conflict(key.to_owned()),
        (413, _) => too_large(),
        (_, Some("NoSuchKey" | "NoSuchBucket" | "NotFound")) => {
            ObjectStoreError::NotFound(key.to_owned())
        }
        (_, Some("AccessDenied")) => ObjectStoreError::AccessDenied(key.to_owned()),
        (_, Some("PreconditionFailed")) => ObjectStoreError::PreconditionFailed(key.to_owned()),
        (_, Some("ConditionalRequestConflict")) => ObjectStoreError::Conflict(key.to_owned()),
        (_, Some("EntityTooLarge")) => too_large(),
        _ => return None,
    };
    Some(mapped)
}
#[async_trait::async_trait]
impl ObjectStore for S3Store {
/// Lists every object under `prefix`, following continuation tokens until the service
/// reports the listing is no longer truncated (or stops supplying a token).
///
/// # Errors
/// Returns the classified request error, or the conversion error from `object_to_meta`
/// for the first entry that fails to convert.
async fn list(&self, prefix: &str) -> Result<Vec<ObjectMeta>, ObjectStoreError> {
    let mut collected = Vec::new();
    let mut continuation: Option<String> = None;
    loop {
        let page = self
            .client
            .list_objects_v2()
            .bucket(&self.bucket)
            .prefix(prefix)
            .set_continuation_token(continuation.take())
            .send()
            .await
            .map_err(|e| classify(e, prefix))?;
        let entries = page.contents();
        collected.reserve(entries.len());
        for entry in entries {
            collected.push(object_to_meta(entry)?);
        }
        let has_more = page.is_truncated().unwrap_or(false);
        match page.next_continuation_token() {
            Some(next) if has_more => continuation = Some(next.to_owned()),
            _ => break,
        }
    }
    Ok(collected)
}
/// Downloads `key` to `dest` via `head_then_download`, staging the data in a temporary
/// file created in `dest`'s parent directory.
///
/// If the first attempt fails with `PreconditionFailed`, the whole head-and-download
/// sequence is retried exactly once.
///
/// # Errors
/// Returns `ObjectStoreError::Other` when `dest` has no parent directory; otherwise the
/// result of the last attempt.
async fn get_to_file(
    &self,
    key: &str,
    dest: &Path,
    opts: GetOpts,
) -> Result<(), ObjectStoreError> {
    let Some(parent) = dest.parent() else {
        return Err(ObjectStoreError::Other(
            format!("destination `{}` has no parent directory", dest.display()).into(),
        ));
    };
    let progress = opts.progress.as_ref();
    let first_attempt = self.head_then_download(key, dest, parent, progress).await;
    if !matches!(
        &first_attempt,
        Err(ObjectStoreError::PreconditionFailed(_))
    ) {
        return first_attempt;
    }
    tracing::warn!(key, "object changed between head and GET; retrying");
    self.head_then_download(key, dest, parent, progress).await
}
/// Fetches the whole object `key` into memory.
///
/// # Errors
/// Request failures are classified with `classify`; failures while reading the body are
/// reported as network errors.
async fn get_bytes(&self, key: &str) -> Result<Bytes, ObjectStoreError> {
    let request = self.client.get_object().bucket(&self.bucket).key(key);
    let output = request.send().await.map_err(|e| classify(e, key))?;
    output
        .body
        .collect()
        .await
        .map(|aggregated| aggregated.into_bytes())
        .map_err(network_boxed)
}
/// Fetches the half-open byte range `range` of `key`.
///
/// `super::precheck_range` runs first and may short-circuit with a ready-made result.
/// The request uses an inclusive `bytes=start-(end-1)` header, and the body is checked by
/// `super::verify_range_response_length` before being returned.
///
/// # Errors
/// An HTTP 416 service error becomes `RangeNotSatisfiable`; other request errors are
/// classified with `classify`; body read failures are reported as network errors.
async fn get_bytes_range(
    &self,
    key: &str,
    range: std::ops::Range<u64>,
) -> Result<Bytes, ObjectStoreError> {
    if let Some(empty) = super::precheck_range(key, &range)? {
        return Ok(empty);
    }
    let header = format!("bytes={}-{}", range.start, range.end - 1);
    let outcome = self
        .client
        .get_object()
        .bucket(&self.bucket)
        .key(key)
        .range(header)
        .send()
        .await;
    let resp = match outcome {
        Ok(resp) => resp,
        Err(SdkError::ServiceError(svc)) if svc.raw().status().as_u16() == 416 => {
            return Err(ObjectStoreError::RangeNotSatisfiable {
                key: key.to_owned(),
                requested: range,
            });
        }
        Err(other) => return Err(classify(other, key)),
    };
    let body = resp
        .body
        .collect()
        .await
        .map_err(network_boxed)?
        .into_bytes();
    super::verify_range_response_length(key, &range, body)
}
/// Uploads an in-memory body to `key`.
///
/// Bodies for which `should_use_multipart` returns true go through
/// `multipart_put_bytes`. Everything else is sent as a single PUT, after which the full
/// size is reported to the progress sink in one step (nothing is reported for an empty
/// body).
async fn put_bytes(
    &self,
    key: &str,
    body: Bytes,
    opts: PutOpts,
) -> Result<(), ObjectStoreError> {
    let size = body.len() as u64;
    if should_use_multipart(size) {
        return self.multipart_put_bytes(key, body, size, opts).await;
    }
    let progress = opts.progress.clone();
    self.put_body(key, ByteStream::from(body), opts).await?;
    if let Some(sink) = progress.filter(|_| size > 0) {
        sink.report(size);
    }
    Ok(())
}
/// Uploads the file at `src` to `key`.
///
/// The file length is read from its metadata. Files for which `should_use_multipart`
/// returns true go through `multipart_put_path`. Everything else is streamed as a single
/// PUT with an exact length, after which the full size is reported to the progress sink
/// (nothing is reported for an empty file).
///
/// # Errors
/// Opening the file, reading its metadata and building the stream all fail as `other`
/// errors; upload failures come from `put_body` / `multipart_put_path`.
async fn put_path(&self, key: &str, src: &Path, opts: PutOpts) -> Result<(), ObjectStoreError> {
    let file = tokio::fs::File::open(src).await.map_err(other_boxed)?;
    let metadata = file.metadata().await.map_err(other_boxed)?;
    let body_len = metadata.len();
    if should_use_multipart(body_len) {
        return self.multipart_put_path(key, file, body_len, opts).await;
    }
    let stream = ByteStream::read_from()
        .file(file)
        .length(Length::Exact(body_len))
        .build()
        .await
        .map_err(other_boxed)?;
    let progress = opts.progress.clone();
    self.put_body(key, stream, opts).await?;
    if let Some(sink) = progress.filter(|_| body_len > 0) {
        sink.report(body_len);
    }
    Ok(())
}
/// Writes `body` to `key` with `If-None-Match: *`.
///
/// Returns `Ok(true)` when the PUT succeeds and `Ok(false)` when it is rejected with a
/// `PreconditionFailed` or `Conflict` classification. Any other failure is returned as
/// an error.
async fn put_if_absent(&self, key: &str, body: Bytes) -> Result<bool, ObjectStoreError> {
    let outcome = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(key)
        .if_none_match("*")
        .body(ByteStream::from(body))
        .send()
        .await;
    let err = match outcome {
        Ok(_) => return Ok(true),
        Err(sdk_err) => classify(sdk_err, key),
    };
    if matches!(
        &err,
        ObjectStoreError::PreconditionFailed(_) | ObjectStoreError::Conflict(_)
    ) {
        Ok(false)
    } else {
        Err(err)
    }
}
/// Issues `HeadObject` for `key` and converts the response with `head_output_to_meta`.
///
/// # Errors
/// Request failures are classified with `classify`; conversion failures come from
/// `head_output_to_meta`.
async fn head(&self, key: &str) -> Result<ObjectMeta, ObjectStoreError> {
    let request = self.client.head_object().bucket(&self.bucket).key(key);
    match request.send().await {
        Ok(out) => head_output_to_meta(
            key,
            out.content_length(),
            out.last_modified(),
            out.e_tag(),
        ),
        Err(err) => Err(classify(err, key)),
    }
}
/// Copies `src` to `dst` inside this store's bucket.
///
/// The source is HEADed first. Sources for which `should_use_multipart` returns true are
/// copied with `multipart_copy`; the rest use a single `CopyObject`. When the HEAD
/// returned an ETag, it is sent as the copy-source `If-Match` condition.
///
/// Errors from the copy request are classified with `src` as the context key.
async fn copy(&self, src: &str, dst: &str) -> Result<(), ObjectStoreError> {
    let meta = self.head(src).await?;
    let src_etag = meta.etag.as_deref();
    if should_use_multipart(meta.size) {
        return self.multipart_copy(src, dst, meta.size, src_etag).await;
    }
    // NOTE(review): the directive is `Replace` but no metadata is attached to the
    // request -- confirm that dropping the source's metadata is intended.
    let base = self
        .client
        .copy_object()
        .bucket(&self.bucket)
        .key(dst)
        .copy_source(encode_copy_source(&self.bucket, src))
        .metadata_directive(MetadataDirective::Replace);
    let request = match src_etag {
        Some(etag) => base.copy_source_if_match(etag),
        None => base,
    };
    request.send().await.map_err(|e| classify(e, src))?;
    Ok(())
}
/// Deletes `key`.
///
/// A HEAD is issued first and its error is propagated, so a key that `head` cannot
/// find fails before any `DeleteObject` request is sent.
async fn delete(&self, key: &str) -> Result<(), ObjectStoreError> {
    self.head(key).await?;
    let request = self.client.delete_object().bucket(&self.bucket).key(key);
    request
        .send()
        .await
        .map(|_| ())
        .map_err(|e| classify(e, key))
}
/// Returns a presigned GET URL for `key` that expires after `ttl`.
///
/// # Errors
/// Returns `ObjectStoreError::Other` when `ttl` is rejected by `PresigningConfig`, and
/// the classified SDK error when presigning itself fails.
async fn presigned_get_url(
    &self,
    key: &str,
    ttl: std::time::Duration,
) -> Result<String, ObjectStoreError> {
    let config = match aws_sdk_s3::presigning::PresigningConfig::expires_in(ttl) {
        Ok(cfg) => cfg,
        Err(e) => {
            return Err(ObjectStoreError::Other(
                format!("PresigningConfig::expires_in({ttl:?}): {e}").into(),
            ));
        }
    };
    let request = self.client.get_object().bucket(&self.bucket).key(key);
    let presigned = request
        .presigned(config)
        .await
        .map_err(|e| classify(e, key))?;
    Ok(presigned.uri().to_owned())
}
}
impl S3Store {
/// HEADs `key`, downloads it into a temporary file inside `parent`, then hands the
/// temporary file to `persist_temp` to place it at `dest`.
///
/// * size 0 -> the empty temporary file is persisted without any GET;
/// * size up to `MULTIPART_THRESHOLD` -> `download_single`;
/// * larger -> `download_multipart`.
///
/// The ETag from the HEAD, when present, is forwarded to the download step.
async fn head_then_download(
    &self,
    key: &str,
    dest: &Path,
    parent: &Path,
    progress: Option<&ProgressSink>,
) -> Result<(), ObjectStoreError> {
    let meta = self.head(key).await?;
    let temp = NamedTempFile::new_in(parent).map_err(other_boxed)?;
    let etag = meta.etag.as_deref();
    match meta.size {
        0 => {}
        size if size <= MULTIPART_THRESHOLD => {
            self.download_single(key, temp.path(), etag, progress)
                .await?;
        }
        size => {
            self.download_multipart(key, temp.path(), size, etag, progress)
                .await?;
        }
    }
    persist_temp(temp, dest)
}
/// Sends a single `PutObject` for `key` with the given body.
///
/// The content disposition and every user-metadata entry from `opts` are attached to the
/// request. The request runs with `upload_timeout_config` as a per-call override.
async fn put_body(
    &self,
    key: &str,
    body: ByteStream,
    opts: PutOpts,
) -> Result<(), ObjectStoreError> {
    let mut request = self
        .client
        .put_object()
        .bucket(&self.bucket)
        .key(key)
        .body(body);
    if let Some(disposition) = opts.content_disposition.as_ref() {
        request = request.content_disposition(disposition);
    }
    for (name, value) in &opts.user_metadata {
        request = request.metadata(name, value);
    }
    let timeout_override =
        aws_sdk_s3::config::Builder::new().timeout_config(upload_timeout_config());
    request
        .customize()
        .config_override(timeout_override)
        .send()
        .await
        .map(|_| ())
        .map_err(|e| classify(e, key))
}
/// Streams `key` into the existing file at `temp_path` with one GET.
///
/// When `etag` is given it is sent as `If-Match`. The file is opened for writing and
/// truncated, each body chunk is written as it arrives, and every non-empty chunk's
/// length is reported to `progress`. The file is flushed before returning.
async fn download_single(
    &self,
    key: &str,
    temp_path: &Path,
    etag: Option<&str>,
    progress: Option<&ProgressSink>,
) -> Result<(), ObjectStoreError> {
    let base = self.client.get_object().bucket(&self.bucket).key(key);
    let request = match etag {
        Some(tag) => base.if_match(tag),
        None => base,
    };
    let mut resp = request.send().await.map_err(|e| classify(e, key))?;
    let mut file = tokio::fs::OpenOptions::new()
        .write(true)
        .truncate(true)
        .open(temp_path)
        .await
        .map_err(other_boxed)?;
    while let Some(chunk) = resp.body.next().await {
        let bytes = chunk.map_err(network_boxed)?;
        file.write_all(&bytes).await.map_err(other_boxed)?;
        let written = bytes.len() as u64;
        if let Some(sink) = progress.filter(|_| written > 0) {
            sink.report(written);
        }
    }
    file.flush().await.map_err(other_boxed)?;
    Ok(())
}
/// Downloads `key` into the existing file at `temp_path` using concurrent ranged GETs.
///
/// The file is pre-sized to `size`, the byte span is split by `plan_ranges` into pieces
/// of `MULTIPART_CHUNK_SIZE`, and one task per range is spawned on a `JoinSet`. A
/// semaphore limits how many tasks fetch at once to `MULTIPART_MAX_CONCURRENCY`. When
/// `etag` is given, every GET carries it as `If-Match`. After a range is written,
/// its length is reported to `progress` (when present).
///
/// # Errors
/// Returns the first failure seen while joining tasks: a classified request error, a
/// network error while reading a body, an `Other` error for a short or long range
/// response, or an `other` error for file I/O and task join failures.
async fn download_multipart(
&self,
key: &str,
temp_path: &Path,
size: u64,
etag: Option<&str>,
progress: Option<&ProgressSink>,
) -> Result<(), ObjectStoreError> {
// Open without truncating, then set the final length up front so each task can seek to
// its own offset and write there.
let async_file = tokio::fs::OpenOptions::new()
.write(true)
.truncate(false)
.open(temp_path)
.await
.map_err(other_boxed)?;
async_file.set_len(size).await.map_err(other_boxed)?;
// One shared handle: the mutex serialises each seek + write pair.
let file = Arc::new(Mutex::new(async_file));
let semaphore = Arc::new(Semaphore::new(MULTIPART_MAX_CONCURRENCY));
let mut tasks: JoinSet<Result<(), ObjectStoreError>> = JoinSet::new();
// Owned copies so the spawned tasks do not borrow from the caller.
let etag_owned = etag.map(str::to_owned);
let progress_owned = progress.cloned();
for (start, end) in plan_ranges(size, MULTIPART_CHUNK_SIZE) {
let client = self.client.clone();
let bucket = self.bucket.clone();
let key = key.to_owned();
let etag = etag_owned.clone();
let file = Arc::clone(&file);
let semaphore = Arc::clone(&semaphore);
let progress = progress_owned.clone();
tasks.spawn(async move {
// All tasks are spawned immediately; the permit is what throttles the requests.
let _permit = semaphore.acquire_owned().await.map_err(other_boxed)?;
let mut req = client
.get_object()
.bucket(&bucket)
.key(&key)
.range(format!("bytes={start}-{end}"));
if let Some(etag) = &etag {
req = req.if_match(etag);
}
let resp = req.send().await.map_err(|e| classify(e, &key))?;
// The whole range is buffered in memory before it is written.
let bytes = resp
.body
.collect()
.await
.map_err(network_boxed)?
.into_bytes();
// `end` is inclusive, hence the `+ 1`.
let expected = end - start + 1;
if bytes.len() as u64 != expected {
return Err(ObjectStoreError::Other(
format!(
"range bytes={start}-{end} returned {} bytes, expected {expected}",
bytes.len()
)
.into(),
));
}
let chunk_len = bytes.len() as u64;
let mut f = file.lock().await;
f.seek(SeekFrom::Start(start)).await.map_err(other_boxed)?;
f.write_all(&bytes).await.map_err(other_boxed)?;
// Release the file lock before reporting progress.
drop(f);
if let Some(sink) = &progress {
sink.report(chunk_len);
}
Ok(())
});
}
// The first `?` surfaces a join failure, the second the task's own error.
while let Some(joined) = tasks.join_next().await {
joined.map_err(other_boxed)??;
}
// Flush through whichever handle is available: the unwrapped file when this is the
// last `Arc` reference, otherwise via the shared mutex.
match Arc::try_unwrap(file) {
Ok(mutex) => {
let mut f = mutex.into_inner();
f.flush().await.map_err(other_boxed)?;
}
Err(shared) => {
let mut f = shared.lock().await;
f.flush().await.map_err(other_boxed)?;
}
}
Ok(())
}
async fn multipart_put_bytes(
&self,
key: &str,
body: Bytes,
size: u64,
opts: PutOpts,
) -> Result<(), ObjectStoreError> {
let parts = plan_upload_parts(size, MULTIPART_PUT_PART_SIZE, S3_MAX_PARTS);
let guard = self.start_multipart_upload(key, &opts).await?;
let progress = opts.progress.clone();
let result = self
.upload_parts_with_bodies(key, guard.upload_id(), &parts, progress, |part| {
slice_bytes_part(&body, part)
})
.await;
self.finish_multipart_upload(guard, result).await
}
/// Uploads an already-open file to `key` as a multipart upload.
///
/// Parts are planned with `plan_upload_parts` and the upload is opened with
/// `start_multipart_upload`. The tokio file is converted to a shared `std::fs::File` for
/// `upload_parts_from_file`, and the outcome is handed to `finish_multipart_upload`.
async fn multipart_put_path(
    &self,
    key: &str,
    file: tokio::fs::File,
    size: u64,
    opts: PutOpts,
) -> Result<(), ObjectStoreError> {
    let parts = plan_upload_parts(size, MULTIPART_PUT_PART_SIZE, S3_MAX_PARTS);
    let guard = self.start_multipart_upload(key, &opts).await?;
    let std_file = file.into_std().await;
    let shared: Arc<std::fs::File> = Arc::new(std_file);
    let uploaded = self
        .upload_parts_from_file(
            key,
            guard.upload_id(),
            shared,
            &parts,
            opts.progress.clone(),
        )
        .await;
    self.finish_multipart_upload(guard, uploaded).await
}
/// Copies `src` to `dst` server-side as a multipart upload.
///
/// Parts are planned over `size` with `plan_upload_parts`. The destination upload is
/// opened with `PutOpts::default()`, each part is copied by `upload_parts_via_copy`
/// (conditioned on `src_etag` when given), and the outcome is handed to
/// `finish_multipart_upload`.
async fn multipart_copy(
    &self,
    src: &str,
    dst: &str,
    size: u64,
    src_etag: Option<&str>,
) -> Result<(), ObjectStoreError> {
    let parts = plan_upload_parts(size, MULTIPART_PUT_PART_SIZE, S3_MAX_PARTS);
    let guard = self
        .start_multipart_upload(dst, &PutOpts::default())
        .await?;
    let copy_source = encode_copy_source(&self.bucket, src);
    let result = self
        .upload_parts_via_copy(src, dst, guard.upload_id(), &copy_source, src_etag, &parts)
        .await;
    self.finish_multipart_upload(guard, result).await
}
/// Opens a multipart upload for `key` and wraps its upload id in an armed
/// `MultipartUploadGuard`.
///
/// The content disposition and every user-metadata entry from `opts` are attached to the
/// `CreateMultipartUpload` request.
///
/// # Errors
/// Returns the classified request error, or `ObjectStoreError::Other` when the response
/// carries no upload id.
async fn start_multipart_upload(
    &self,
    key: &str,
    opts: &PutOpts,
) -> Result<MultipartUploadGuard, ObjectStoreError> {
    let mut request = self
        .client
        .create_multipart_upload()
        .bucket(&self.bucket)
        .key(key);
    if let Some(disposition) = opts.content_disposition.as_ref() {
        request = request.content_disposition(disposition);
    }
    for (name, value) in &opts.user_metadata {
        request = request.metadata(name, value);
    }
    let created = request.send().await.map_err(|e| classify(e, key))?;
    let Some(upload_id) = created.upload_id() else {
        return Err(ObjectStoreError::Other(
            format!("CreateMultipartUpload for `{key}` returned no upload-id").into(),
        ));
    };
    Ok(MultipartUploadGuard::new(
        self.client.clone(),
        self.bucket.clone(),
        key.to_owned(),
        upload_id.to_owned(),
    ))
}
/// Uploads every entry of `parts` with `UploadPart`, getting each body from `make_body`.
///
/// One task per part is spawned on a `JoinSet`; a semaphore limits concurrent uploads to
/// `MULTIPART_PUT_MAX_CONCURRENCY`. Part numbers are the 1-based positions in `parts`.
/// Each successful part reports `part.length` to `progress` (when present). The completed
/// parts are collected by `join_completed_parts`.
///
/// # Errors
/// Returns immediately if `make_body` fails for any part; otherwise returns whatever
/// `join_completed_parts` yields (classified request errors, or `Other` when a response
/// carries no ETag).
///
/// # Panics
/// Panics if a part's 1-based index does not fit in `i32`.
async fn upload_parts_with_bodies<F>(
&self,
key: &str,
upload_id: &str,
parts: &[UploadPart],
progress: Option<ProgressSink>,
make_body: F,
) -> Result<Vec<CompletedPart>, ObjectStoreError>
where
F: Fn(UploadPart) -> Result<Bytes, ObjectStoreError>,
{
let semaphore = Arc::new(Semaphore::new(MULTIPART_PUT_MAX_CONCURRENCY));
let mut tasks: JoinSet<Result<CompletedPart, ObjectStoreError>> = JoinSet::new();
for (idx, part) in parts.iter().enumerate() {
let part = *part;
let part_number = i32::try_from(idx + 1)
.expect("plan_upload_parts caps parts <= S3_MAX_PARTS = 10_000");
// The body is produced here, in the spawning loop, before the task waits for a permit.
let body = make_body(part)?;
// Owned copies so the spawned task does not borrow from the caller.
let client = self.client.clone();
let bucket = self.bucket.clone();
let key = key.to_owned();
let upload_id = upload_id.to_owned();
let semaphore = Arc::clone(&semaphore);
let progress = progress.clone();
tasks.spawn(async move {
let _permit = semaphore.acquire_owned().await.map_err(other_boxed)?;
// The per-request override swaps in `upload_timeout_config` for this call.
let resp = client
.upload_part()
.bucket(&bucket)
.key(&key)
.upload_id(&upload_id)
.part_number(part_number)
.body(ByteStream::from(body))
.customize()
.config_override(
aws_sdk_s3::config::Builder::new().timeout_config(upload_timeout_config()),
)
.send()
.await
.map_err(|e| classify(e, &key))?;
let etag = resp.e_tag().map(str::to_owned).ok_or_else(|| {
ObjectStoreError::Other(
format!("UploadPart for `{key}` part {part_number} returned no ETag")
.into(),
)
})?;
if let Some(sink) = &progress {
sink.report(part.length);
}
Ok(CompletedPart::builder()
.part_number(part_number)
.e_tag(etag)
.build())
});
}
join_completed_parts(tasks, parts.len()).await
}
/// Uploads every entry of `parts` with `UploadPart`, reading each body from `file` with
/// `read_file_part`.
///
/// One task per part is spawned on a `JoinSet`; a semaphore limits concurrent uploads to
/// `MULTIPART_PUT_MAX_CONCURRENCY`. Part numbers are the 1-based positions in `parts`.
/// Each successful part reports `part.length` to `progress` (when present). The completed
/// parts are collected by `join_completed_parts`.
///
/// # Errors
/// Returns whatever `join_completed_parts` yields: read errors from `read_file_part`,
/// classified request errors, or `Other` when a response carries no ETag.
///
/// # Panics
/// Panics if a part's 1-based index does not fit in `i32`.
async fn upload_parts_from_file(
&self,
key: &str,
upload_id: &str,
file: Arc<std::fs::File>,
parts: &[UploadPart],
progress: Option<ProgressSink>,
) -> Result<Vec<CompletedPart>, ObjectStoreError> {
let semaphore = Arc::new(Semaphore::new(MULTIPART_PUT_MAX_CONCURRENCY));
let mut tasks: JoinSet<Result<CompletedPart, ObjectStoreError>> = JoinSet::new();
for (idx, part) in parts.iter().enumerate() {
let part = *part;
let part_number = i32::try_from(idx + 1)
.expect("plan_upload_parts caps parts <= S3_MAX_PARTS = 10_000");
// Owned copies so the spawned task does not borrow from the caller.
let client = self.client.clone();
let bucket = self.bucket.clone();
let key = key.to_owned();
let upload_id = upload_id.to_owned();
let task_file = Arc::clone(&file);
let semaphore = Arc::clone(&semaphore);
let progress = progress.clone();
tasks.spawn(async move {
let _permit = semaphore.acquire_owned().await.map_err(other_boxed)?;
// The part is read only after a permit is held.
let body = read_file_part(task_file, part).await?;
// The per-request override swaps in `upload_timeout_config` for this call.
let resp = client
.upload_part()
.bucket(&bucket)
.key(&key)
.upload_id(&upload_id)
.part_number(part_number)
.body(ByteStream::from(body))
.customize()
.config_override(
aws_sdk_s3::config::Builder::new().timeout_config(upload_timeout_config()),
)
.send()
.await
.map_err(|e| classify(e, &key))?;
let etag = resp.e_tag().map(str::to_owned).ok_or_else(|| {
ObjectStoreError::Other(
format!("UploadPart for `{key}` part {part_number} returned no ETag")
.into(),
)
})?;
if let Some(sink) = &progress {
sink.report(part.length);
}
Ok(CompletedPart::builder()
.part_number(part_number)
.e_tag(etag)
.build())
});
}
join_completed_parts(tasks, parts.len()).await
}
/// Fills the multipart upload `upload_id` on `dst` by copying byte ranges of the source
/// with `UploadPartCopy`.
///
/// One task per entry of `parts` is spawned on a `JoinSet`; a semaphore limits
/// concurrent copies to `MULTIPART_PUT_MAX_CONCURRENCY`. Part numbers are the 1-based
/// positions in `parts`. Each part copies the inclusive range
/// `offset ..= offset + length - 1` of `copy_source`. When `src_etag` is given, it is
/// sent as the copy-source `If-Match` condition on every part.
///
/// `src` is used only as error context. The completed parts are collected by
/// `join_completed_parts`.
///
/// # Errors
/// Returns whatever `join_completed_parts` yields: request errors classified against
/// `src`, or `Other` when a response carries no ETag.
///
/// # Panics
/// Panics if a part's 1-based index does not fit in `i32`.
async fn upload_parts_via_copy(
    &self,
    src: &str,
    dst: &str,
    upload_id: &str,
    copy_source: &str,
    src_etag: Option<&str>,
    parts: &[UploadPart],
) -> Result<Vec<CompletedPart>, ObjectStoreError> {
    let semaphore = Arc::new(Semaphore::new(MULTIPART_PUT_MAX_CONCURRENCY));
    let mut tasks: JoinSet<Result<CompletedPart, ObjectStoreError>> = JoinSet::new();
    for (idx, part) in parts.iter().enumerate() {
        let part = *part;
        let part_number = i32::try_from(idx + 1)
            .expect("plan_upload_parts caps parts <= S3_MAX_PARTS = 10_000");
        // Owned copies so the spawned task does not borrow from the caller.
        let client = self.client.clone();
        let bucket = self.bucket.clone();
        let dst = dst.to_owned();
        let src_ctx = src.to_owned();
        let upload_id = upload_id.to_owned();
        let copy_source = copy_source.to_owned();
        let src_etag = src_etag.map(str::to_owned);
        // The range header is inclusive on both ends.
        let range = format!("bytes={}-{}", part.offset, part.offset + part.length - 1);
        let semaphore = Arc::clone(&semaphore);
        tasks.spawn(async move {
            let _permit = semaphore.acquire_owned().await.map_err(other_boxed)?;
            let mut req = client
                .upload_part_copy()
                .bucket(&bucket)
                .key(&dst)
                .upload_id(&upload_id)
                .part_number(part_number)
                .copy_source(&copy_source)
                .copy_source_range(&range);
            if let Some(etag) = &src_etag {
                req = req.copy_source_if_match(etag);
            }
            // The per-request override swaps in `upload_timeout_config` for this call.
            let resp = req
                .customize()
                .config_override(
                    aws_sdk_s3::config::Builder::new().timeout_config(upload_timeout_config()),
                )
                .send()
                .await
                .map_err(|e| classify(e, &src_ctx))?;
            let etag = resp
                .copy_part_result()
                .and_then(|r| r.e_tag())
                .map(str::to_owned)
                .ok_or_else(|| {
                    ObjectStoreError::Other(
                        format!(
                            "UploadPartCopy for `{src_ctx}` → `{dst}` part {part_number} returned no ETag"
                        )
                        .into(),
                    )
                })?;
            Ok(CompletedPart::builder()
                .part_number(part_number)
                .e_tag(etag)
                .build())
        });
    }
    join_completed_parts(tasks, parts.len()).await
}
/// Completes or aborts the multipart upload tracked by `guard`, depending on `parts`.
///
/// * `Ok(parts)` - sends `CompleteMultipartUpload` with the given parts and disarms the
///   guard once it succeeds. If completion fails, the error is returned with the guard
///   still armed, so the guard's `Drop` dispatches the abort.
/// * `Err(err)` - sends `AbortMultipartUpload` inline (a failure is only logged),
///   disarms the guard, and returns `err` unchanged.
async fn finish_multipart_upload(
&self,
mut guard: MultipartUploadGuard,
parts: Result<Vec<CompletedPart>, ObjectStoreError>,
) -> Result<(), ObjectStoreError> {
match parts {
Ok(parts) => {
let multipart = CompletedMultipartUpload::builder()
.set_parts(Some(parts))
.build();
self.client
.complete_multipart_upload()
.bucket(&self.bucket)
.key(guard.key())
.upload_id(guard.upload_id())
.multipart_upload(multipart)
.send()
.await
.map_err(|e| classify(e, guard.key()))?;
// Reached only after completion succeeded; nothing is left to abort.
guard.disarm();
Ok(())
}
Err(err) => {
// Best-effort abort: a failure here is logged and the original error still wins.
if let Err(abort_err) = self
.client
.abort_multipart_upload()
.bucket(&self.bucket)
.key(guard.key())
.upload_id(guard.upload_id())
.send()
.await
{
tracing::warn!(
key = %guard.key(),
upload_id = %guard.upload_id(),
?abort_err,
"AbortMultipartUpload failed; orphan upload may incur storage cost \
until lifecycle expiry",
);
}
// The abort was already attempted above, so the drop-time abort is suppressed.
guard.disarm();
Err(err)
}
}
}
}
/// Tracks one in-progress multipart upload so it can be aborted if the guard is dropped
/// while still armed.
struct MultipartUploadGuard {
/// Client used to send the drop-time `AbortMultipartUpload`.
client: aws_sdk_s3::Client,
/// Bucket the upload was created in.
bucket: String,
/// Object key the upload targets.
key: String,
/// Upload id the guard was constructed with.
upload_id: String,
/// While `true`, dropping the guard dispatches an abort; cleared by `disarm`.
armed: bool,
}
impl MultipartUploadGuard {
    /// Creates a guard for the given upload; it starts armed.
    fn new(client: aws_sdk_s3::Client, bucket: String, key: String, upload_id: String) -> Self {
        let armed = true;
        Self {
            client,
            bucket,
            key,
            upload_id,
            armed,
        }
    }

    /// Upload id this guard tracks.
    fn upload_id(&self) -> &str {
        self.upload_id.as_str()
    }

    /// Object key this guard tracks.
    fn key(&self) -> &str {
        self.key.as_str()
    }

    /// Marks the upload as settled so that dropping the guard does nothing.
    fn disarm(&mut self) {
        self.armed = false;
    }
}
/// Drop-time cleanup: an armed guard dispatches `AbortMultipartUpload` on the current
/// tokio runtime without waiting for it. A disarmed guard does nothing.
impl Drop for MultipartUploadGuard {
fn drop(&mut self) {
if !self.armed {
return;
}
// `drop` cannot await, so the abort must be spawned; without a runtime handle the
// only option is to log and leave the upload in place.
let Ok(handle) = tokio::runtime::Handle::try_current() else {
tracing::warn!(
key = %self.key,
upload_id = %self.upload_id,
"MultipartUploadGuard dropped outside a tokio runtime; \
cannot dispatch AbortMultipartUpload (orphan upload may \
incur storage cost until S3 lifecycle expiry)",
);
return;
};
// Move the owned fields into the task; `mem::take` leaves empty strings behind.
let client = self.client.clone();
let bucket = std::mem::take(&mut self.bucket);
let key = std::mem::take(&mut self.key);
let upload_id = std::mem::take(&mut self.upload_id);
// The task is detached: its result is never awaited, and a failure is only logged.
handle.spawn(async move {
if let Err(abort_err) = client
.abort_multipart_upload()
.bucket(&bucket)
.key(&key)
.upload_id(&upload_id)
.send()
.await
{
tracing::warn!(
key = %key,
upload_id = %upload_id,
?abort_err,
"AbortMultipartUpload (drop-fire) failed; orphan upload may \
incur storage cost until S3 lifecycle expiry",
);
}
});
}
}
async fn join_completed_parts(
mut tasks: JoinSet<Result<CompletedPart, ObjectStoreError>>,
capacity: usize,
) -> Result<Vec<CompletedPart>, ObjectStoreError> {
let mut completed = Vec::with_capacity(capacity);
while let Some(joined) = tasks.join_next().await {
let part = joined.map_err(other_boxed)??;
completed.push(part);
}
completed.sort_by_key(|p| {
p.part_number()
.expect("CompletedPart built with explicit part_number")
});
Ok(completed)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::url::{AzureAddressing, RemoteFlags};
use aws_sdk_s3::primitives::DateTime;
use aws_sdk_s3::types::Object;
fn parse_endpoint(s: &str) -> Url {
Url::parse(s).expect("test endpoint URL parses")
}
#[test]
fn object_to_meta_round_trips_well_formed_object() {
let modified = DateTime::from_secs(1_700_000_000);
let obj = Object::builder()
.key("refs/heads/main/abc.bundle")
.size(42)
.last_modified(modified)
.build();
let meta = object_to_meta(&obj).expect("conversion succeeds");
assert_eq!(meta.key, "refs/heads/main/abc.bundle");
assert_eq!(meta.size, 42);
assert_eq!(meta.last_modified.unix_timestamp(), 1_700_000_000);
}
#[test]
fn object_to_meta_rejects_missing_key() {
let obj = Object::builder()
.last_modified(DateTime::from_secs(1_700_000_000))
.build();
let err = object_to_meta(&obj).expect_err("missing key must error");
match err {
ObjectStoreError::Other(inner) => {
assert!(
inner.to_string().contains("without a key"),
"error message names the failure: {inner}"
);
}
other => panic!("expected ObjectStoreError::Other for missing key, got {other:?}"),
}
}
#[test]
fn object_to_meta_rejects_missing_last_modified() {
let obj = Object::builder().key("k").size(0).build();
let err = object_to_meta(&obj).expect_err("missing last_modified must error");
match err {
ObjectStoreError::Other(inner) => {
let msg = inner.to_string();
assert!(
msg.contains("without last_modified"),
"names failure: {msg}"
);
assert!(msg.contains("`k`"), "includes the key for context: {msg}");
}
other => {
panic!("expected ObjectStoreError::Other for missing last_modified, got {other:?}")
}
}
}
#[test]
fn head_output_to_meta_round_trips_well_formed_response() {
let modified = DateTime::from_secs(1_700_000_000);
let meta = head_output_to_meta("k", Some(42), Some(&modified), Some("\"abc\""))
.expect("conversion succeeds");
assert_eq!(meta.key, "k");
assert_eq!(meta.size, 42);
assert_eq!(meta.last_modified.unix_timestamp(), 1_700_000_000);
assert_eq!(meta.etag.as_deref(), Some("\"abc\""));
}
#[test]
fn head_output_to_meta_preserves_legitimate_zero_size() {
let modified = DateTime::from_secs(1_700_000_000);
let meta = head_output_to_meta("LOCK", Some(0), Some(&modified), None)
.expect("conversion succeeds");
assert_eq!(meta.size, 0);
}
#[test]
fn head_output_to_meta_rejects_missing_content_length() {
let modified = DateTime::from_secs(1_700_000_000);
let err = head_output_to_meta("k", None, Some(&modified), None)
.expect_err("missing content-length must error");
match err {
ObjectStoreError::Other(inner) => {
let msg = inner.to_string();
assert!(msg.contains("no content-length"), "names failure: {msg}");
assert!(msg.contains("`k`"), "includes the key for context: {msg}");
}
other => {
panic!("expected ObjectStoreError::Other for missing content-length, got {other:?}")
}
}
}
#[test]
fn head_output_to_meta_rejects_missing_last_modified() {
let err = head_output_to_meta("k", Some(0), None, None)
.expect_err("missing last_modified must error");
match err {
ObjectStoreError::Other(inner) => {
let msg = inner.to_string();
assert!(msg.contains("no last_modified"), "names failure: {msg}");
assert!(msg.contains("`k`"), "includes the key for context: {msg}");
}
other => {
panic!("expected ObjectStoreError::Other for missing last_modified, got {other:?}")
}
}
}
#[test]
fn head_output_to_meta_clamps_negative_size_to_zero() {
let modified = DateTime::from_secs(1_700_000_000);
let meta =
head_output_to_meta("k", Some(-1), Some(&modified), None).expect("conversion succeeds");
assert_eq!(meta.size, 0);
}
#[test]
fn object_to_meta_clamps_negative_size_to_zero() {
let obj = Object::builder()
.key("k")
.size(-1)
.last_modified(DateTime::from_secs(1_700_000_000))
.build();
let meta = object_to_meta(&obj).expect("conversion succeeds");
assert_eq!(meta.size, 0);
}
/// A size of zero produces no ranges at all.
#[test]
fn plan_ranges_zero_size_yields_empty_vec() {
    let ranges = plan_ranges(0, 16);
    assert!(ranges.is_empty());
}
/// A chunk size of zero produces no ranges at all.
#[test]
fn plan_ranges_zero_chunk_yields_empty_vec() {
    let ranges = plan_ranges(100, 0);
    assert!(ranges.is_empty());
}
/// A single byte is covered by the one inclusive range `(0, 0)`.
#[test]
fn plan_ranges_size_one_byte() {
    let ranges = plan_ranges(1, 16);
    assert_eq!(ranges, vec![(0, 0)]);
}
/// A size smaller than the chunk yields a single range spanning the whole
/// input.
#[test]
fn plan_ranges_size_below_chunk() {
    let ranges = plan_ranges(10, 16);
    assert_eq!(ranges, vec![(0, 9)]);
}
/// A size equal to the chunk yields exactly one full range.
#[test]
fn plan_ranges_size_equals_chunk() {
    let ranges = plan_ranges(16, 16);
    assert_eq!(ranges, vec![(0, 15)]);
}
/// One byte past the chunk size spills into a second, one-byte range.
#[test]
fn plan_ranges_size_one_byte_above_chunk() {
    let ranges = plan_ranges(17, 16);
    assert_eq!(ranges, vec![(0, 15), (16, 16)]);
}
/// A size that is an exact multiple of the chunk produces only full
/// ranges, with no trailing remainder.
#[test]
fn plan_ranges_exact_multiple_of_chunk() {
    let ranges = plan_ranges(48, 16);
    assert_eq!(
        ranges,
        vec![(0, 15), (16, 31), (32, 47)],
        "three full chunks, no leftover"
    );
}
/// A size that is not a multiple of the chunk ends with a shorter final
/// range.
#[test]
fn plan_ranges_with_partial_final_chunk() {
    let ranges = plan_ranges(50, 16);
    assert_eq!(ranges, vec![(0, 15), (16, 31), (32, 47), (48, 49)]);
}
/// Plans a 6 GiB input in 16 MiB chunks and checks the range count plus
/// the first and last inclusive ranges.
#[test]
fn plan_ranges_handles_huge_size_without_overflow() {
    const MIB: u64 = 1024 * 1024;
    let chunk = 16 * MIB;
    let size = 6 * 1024 * MIB;

    let ranges = plan_ranges(size, chunk);

    // 6 GiB / 16 MiB = 384 full chunks.
    assert_eq!(ranges.len(), 384);
    let head = ranges.first().copied();
    assert_eq!(head, Some((0, chunk - 1)));
    let tail = ranges.last().copied();
    assert_eq!(tail, Some((size - chunk, size - 1)));
}
/// For a path-style endpoint the bucket path segment is dropped: the host
/// is kept, the path becomes `/`, and there is no query.
#[test]
fn normalize_endpoint_path_style_strips_bucket_path() {
    let endpoint = parse_endpoint("https://s3.us-west-2.amazonaws.com/my-bucket");
    let normalized = normalize_endpoint(&endpoint, S3Addressing::PathStyle).unwrap();

    assert_eq!(normalized.host_str(), Some("s3.us-west-2.amazonaws.com"));
    assert_eq!(normalized.path(), "/");
    assert!(normalized.query().is_none());
}
/// The query string (here `?addressing=path`) is removed during
/// normalisation, while host and port are preserved.
#[test]
fn normalize_endpoint_strips_query_string() {
    let endpoint = parse_endpoint("http://127.0.0.1:9000/my-bucket?addressing=path");
    let out = normalize_endpoint(&endpoint, S3Addressing::PathStyle).unwrap();

    assert!(out.query().is_none(), "query must be stripped: {out}");
    assert_eq!(out.path(), "/");
    assert_eq!(out.host_str(), Some("127.0.0.1"));
    assert_eq!(out.port(), Some(9000));
}
/// For a virtual-hosted endpoint the leading bucket label is removed from
/// the host; the scheme is kept and the path is `/`.
#[test]
fn normalize_endpoint_strips_bucket_label_for_virtual_hosted() {
    let endpoint = parse_endpoint("https://my-bucket.s3.us-west-2.amazonaws.com/");
    let normalized = normalize_endpoint(&endpoint, S3Addressing::VirtualHosted).unwrap();

    assert_eq!(normalized.host_str(), Some("s3.us-west-2.amazonaws.com"));
    assert_eq!(normalized.scheme(), "https");
    assert_eq!(normalized.path(), "/");
}
/// Virtual-hosted normalisation keeps the scheme and port, strips the
/// bucket label from the host, and discards the path and query.
#[test]
fn normalize_endpoint_virtual_hosted_preserves_port_and_scheme() {
    let endpoint = parse_endpoint("http://my-bucket.s3.example.com:9000/some/path?x=1");
    let normalized = normalize_endpoint(&endpoint, S3Addressing::VirtualHosted).unwrap();

    assert_eq!(normalized.scheme(), "http");
    assert_eq!(normalized.host_str(), Some("s3.example.com"));
    assert_eq!(normalized.port(), Some(9000));
    assert_eq!(normalized.path(), "/");
    assert!(normalized.query().is_none());
}
/// A bucket name that itself contains a dot (`bucketname.com`) is still
/// stripped in full from a virtual-hosted AWS host.
#[test]
fn normalize_endpoint_dotted_bucket_virtual_hosted() {
    let endpoint = parse_endpoint("https://bucketname.com.s3.us-west-2.amazonaws.com/some/path");
    let normalized = normalize_endpoint(&endpoint, S3Addressing::VirtualHosted).unwrap();

    assert_eq!(normalized.host_str(), Some("s3.us-west-2.amazonaws.com"));
    assert_eq!(normalized.path(), "/");
    assert!(normalized.query().is_none());
}
/// An explicitly supplied region wins over the region embedded in the
/// host name.
#[test]
fn resolve_region_flag_takes_precedence() {
    let endpoint = parse_endpoint("https://my-bucket.s3.us-west-2.amazonaws.com/");
    let region = resolve_region(&endpoint, Some("eu-central-1"));
    assert_eq!(region.as_deref(), Some("eu-central-1"));
}
/// With no explicit region, the region is read from a virtual-hosted AWS
/// host name.
#[test]
fn resolve_region_extracts_from_virtual_hosted_aws_host() {
    let endpoint = parse_endpoint("https://my-bucket.s3.us-west-2.amazonaws.com/");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region.as_deref(), Some("us-west-2"));
}
/// With no explicit region, the region is read from a path-style AWS host
/// name.
#[test]
fn resolve_region_extracts_from_path_style_aws_host() {
    let endpoint = parse_endpoint("https://s3.eu-west-1.amazonaws.com/my-bucket");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region.as_deref(), Some("eu-west-1"));
}
/// The hyphenated host form `s3-<region>.amazonaws.com` also yields its
/// region.
#[test]
fn resolve_region_handles_legacy_hyphenated_form() {
    let endpoint = parse_endpoint("https://s3-ap-south-1.amazonaws.com/my-bucket");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region.as_deref(), Some("ap-south-1"));
}
/// The bare `s3.amazonaws.com` host carries no region segment, so nothing
/// is resolved.
#[test]
fn resolve_region_legacy_no_segment_returns_none() {
    let endpoint = parse_endpoint("https://s3.amazonaws.com/my-bucket");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region, None);
}
/// A non-AWS host (localhost here) with no explicit region resolves to
/// `us-east-1`.
#[test]
fn resolve_region_non_aws_host_defaults_to_us_east_1() {
    let endpoint = parse_endpoint("http://localhost:9000/my-bucket");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region.as_deref(), Some("us-east-1"));
}
/// A Cloudflare R2 style host with no explicit region resolves to
/// `us-east-1`.
#[test]
fn resolve_region_r2_endpoint_defaults_to_us_east_1() {
    let endpoint = parse_endpoint("https://abc123.r2.cloudflarestorage.com/my-bucket");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region.as_deref(), Some("us-east-1"));
}
/// A dotted bucket label in front of the AWS host does not prevent the
/// region from being resolved.
#[test]
fn resolve_region_dotted_bucket_virtual_hosted() {
    let endpoint = parse_endpoint("https://bucketname.com.s3.us-west-2.amazonaws.com/some/path");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region.as_deref(), Some("us-west-2"));
}
/// The region is resolved from a virtual-hosted host under the
/// `amazonaws.com.cn` suffix.
#[test]
fn resolve_region_china_partition_virtual_hosted() {
    let endpoint = parse_endpoint("https://my-bucket.s3.cn-north-1.amazonaws.com.cn/repo");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region.as_deref(), Some("cn-north-1"));
}
/// The region is resolved from a path-style host under the
/// `amazonaws.com.cn` suffix.
#[test]
fn resolve_region_china_partition_path_style() {
    let endpoint = parse_endpoint("https://s3.cn-northwest-1.amazonaws.com.cn/my-bucket");
    let region = resolve_region(&endpoint, None);
    assert_eq!(region.as_deref(), Some("cn-northwest-1"));
}
/// Slashes, both the bucket/key separator and those inside the key, pass
/// through unencoded.
#[test]
fn encode_copy_source_preserves_slash_between_bucket_and_key() {
    let (bucket, key) = ("my-bucket", "refs/heads/main/abc.bundle");
    let encoded = encode_copy_source(bucket, key);
    assert_eq!(encoded, "my-bucket/refs/heads/main/abc.bundle");
}
/// A `#` inside the key is percent-encoded as `%23`.
#[test]
fn encode_copy_source_encodes_hash_in_lock_keys() {
    let (bucket, key) = ("my-bucket", "refs/heads/main/LOCK#.lock");
    let encoded = encode_copy_source(bucket, key);
    assert_eq!(encoded, "my-bucket/refs/heads/main/LOCK%23.lock");
}
/// Space, `?` and `=` in the key are each percent-encoded.
#[test]
fn encode_copy_source_encodes_spaces_and_query_chars() {
    let (bucket, key) = ("my-bucket", "weird key?with=stuff");
    let out = encode_copy_source(bucket, key);

    assert!(out.contains("%20"), "space encoded: {out}");
    assert!(out.contains("%3F"), "? encoded: {out}");
    assert!(out.contains("%3D"), "= encoded: {out}");
}
/// Letters, digits and `-`, `.`, `_`, `~` are emitted unchanged in both
/// bucket and key.
#[test]
fn encode_copy_source_passes_unreserved_through() {
    let (bucket, key) = ("my.bucket-name_v1~", "abc-def_ghi.txt");
    let encoded = encode_copy_source(bucket, key);
    assert_eq!(encoded, "my.bucket-name_v1~/abc-def_ghi.txt");
}
/// HTTP 404 with no error code classifies as `NotFound` carrying the key.
#[test]
fn classify_404_status_is_not_found() {
    let classified = classify_status_and_code(404, None, "k");
    assert!(matches!(
        classified,
        Some(ObjectStoreError::NotFound(key)) if key == "k"
    ));
}
/// HTTP 403 with no error code classifies as `AccessDenied` carrying the
/// key.
#[test]
fn classify_403_status_is_access_denied() {
    let classified = classify_status_and_code(403, None, "k");
    assert!(matches!(
        classified,
        Some(ObjectStoreError::AccessDenied(key)) if key == "k"
    ));
}
/// HTTP 412 with no error code classifies as `PreconditionFailed`
/// carrying the key.
#[test]
fn classify_412_status_is_precondition_failed() {
    let classified = classify_status_and_code(412, None, "k");
    assert!(matches!(
        classified,
        Some(ObjectStoreError::PreconditionFailed(key)) if key == "k"
    ));
}
/// HTTP 409 with no error code classifies as `Conflict` carrying the key.
#[test]
fn classify_409_status_is_conflict() {
    let classified = classify_status_and_code(409, None, "k");
    assert!(matches!(
        classified,
        Some(ObjectStoreError::Conflict(key)) if key == "k"
    ));
}
/// The `NoSuchKey` error code maps to `NotFound` even when the HTTP
/// status (500 here) would not.
#[test]
fn classify_no_such_key_code_falls_back_to_not_found() {
    let classified = classify_status_and_code(500, Some("NoSuchKey"), "k");
    assert!(matches!(
        classified,
        Some(ObjectStoreError::NotFound(key)) if key == "k"
    ));
}
/// The `ConditionalRequestConflict` error code maps to `Conflict` even
/// when the HTTP status (500 here) would not.
#[test]
fn classify_conditional_request_conflict_code_is_conflict() {
    let classified = classify_status_and_code(500, Some("ConditionalRequestConflict"), "k");
    assert!(matches!(
        classified,
        Some(ObjectStoreError::Conflict(key)) if key == "k"
    ));
}
/// The `EntityTooLarge` error code maps to `PayloadTooLarge`, reporting
/// `SINGLE_PUT_LIMIT_BYTES` as the limit.
#[test]
fn classify_entity_too_large_code_is_payload_too_large() {
    let classified = classify_status_and_code(400, Some("EntityTooLarge"), "k");
    assert!(matches!(
        classified,
        Some(ObjectStoreError::PayloadTooLarge { limit_bytes })
            if limit_bytes == SINGLE_PUT_LIMIT_BYTES
    ));
}
/// HTTP 413 with no error code maps to `PayloadTooLarge`, reporting
/// `SINGLE_PUT_LIMIT_BYTES` as the limit.
#[test]
fn classify_413_status_is_payload_too_large() {
    let classified = classify_status_and_code(413, None, "k");
    assert!(matches!(
        classified,
        Some(ObjectStoreError::PayloadTooLarge { limit_bytes })
            if limit_bytes == SINGLE_PUT_LIMIT_BYTES
    ));
}
/// Status/code combinations with no dedicated mapping produce `None`.
#[test]
fn classify_unrecognised_returns_none() {
    // Shared shape of every case: classify for key "k" and expect no mapping.
    let unclassified = |status, code| classify_status_and_code(status, code, "k").is_none();

    assert!(unclassified(500, Some("InternalError")));
    assert!(unclassified(500, None));
    assert!(unclassified(400, None));
    assert!(unclassified(400, Some("MalformedXML")));
}
/// Builds a `RemoteUrl::Azure` fixture (account `acct`, container
/// `container`, no prefix, default flags) for tests that need a non-S3 URL.
fn azure_url() -> RemoteUrl {
    let endpoint = parse_endpoint("https://acct.blob.core.windows.net/container");
    let account = String::from("acct");
    let container = String::from("container");
    RemoteUrl::Azure {
        endpoint,
        account,
        container,
        prefix: None,
        addressing: AzureAddressing::VirtualHosted,
        flags: RemoteFlags::default(),
    }
}
/// Constructing an `S3Store` from an Azure URL must fail with
/// `ObjectStoreError::Other`.
#[tokio::test]
async fn from_remote_url_rejects_azure() {
    let remote = azure_url();
    match S3Store::from_remote_url(&remote).await {
        Ok(_) => panic!("expected Azure URL to be rejected"),
        Err(ObjectStoreError::Other(_)) => {}
        Err(other) => panic!("expected ObjectStoreError::Other, got {other:?}"),
    }
}
/// Resolving a path-style local endpoint with no profile or region given:
/// path style is forced, the endpoint is reduced to host and port, the
/// region is `us-east-1`, and no profile is set.
#[test]
fn resolved_path_style_minio() {
    let raw = parse_endpoint("http://127.0.0.1:9000/my-bucket?addressing=path");
    let cfg = ResolvedS3Config::from_url_parts(&raw, S3Addressing::PathStyle, None, None)
        .expect("resolves");

    assert!(cfg.force_path_style);

    let url = &cfg.endpoint_url;
    assert_eq!(url.host_str(), Some("127.0.0.1"));
    assert_eq!(url.port(), Some(9000));
    assert_eq!(url.path(), "/");
    assert!(url.query().is_none());

    assert_eq!(cfg.region.as_deref(), Some("us-east-1"));
    assert!(cfg.profile.is_none());
}
/// Resolving a virtual-hosted AWS endpoint: path style is not forced, the
/// bucket label is gone from the endpoint URL, and the region comes from
/// the host.
#[test]
fn resolved_virtual_hosted_aws_strips_bucket_and_picks_region() {
    let raw = parse_endpoint("https://my-bucket.s3.us-west-2.amazonaws.com/");
    let cfg = ResolvedS3Config::from_url_parts(&raw, S3Addressing::VirtualHosted, None, None)
        .expect("resolves");

    assert!(!cfg.force_path_style);

    let url = &cfg.endpoint_url;
    assert_eq!(url.host_str(), Some("s3.us-west-2.amazonaws.com"));
    assert!(
        !url.as_str().contains("my-bucket"),
        "bucket label must be stripped: {}",
        url
    );

    assert_eq!(cfg.region.as_deref(), Some("us-west-2"));
}
/// An explicitly supplied profile and region appear unchanged on the
/// resolved configuration.
#[test]
fn resolved_explicit_flags_propagate() {
    let (profile, region) = ("dev-profile", "eu-central-1");
    let raw = parse_endpoint("http://127.0.0.1:9000/my-bucket");

    let cfg = ResolvedS3Config::from_url_parts(
        &raw,
        S3Addressing::PathStyle,
        Some(profile),
        Some(region),
    )
    .expect("resolves");

    assert_eq!(cfg.region.as_deref(), Some(region));
    assert_eq!(cfg.profile.as_deref(), Some(profile));
}
/// Smoke test: building an S3 config from a resolved path-style endpoint
/// completes without panicking. The built config itself is not inspected.
#[tokio::test]
async fn build_s3_config_round_trips_resolved_decisions() {
    let raw = parse_endpoint("http://127.0.0.1:9000/my-bucket");
    let resolved = ResolvedS3Config::from_url_parts(&raw, S3Addressing::PathStyle, None, None)
        .expect("resolves");

    let _config = build_s3_config(&resolved).await;
}
/// Pins both timeout constants at 30 seconds so a change to either is a
/// deliberate, test-visible edit.
#[test]
fn timeout_constants_have_expected_values() {
    let thirty_seconds = Duration::from_secs(30);
    assert_eq!(POOL_IDLE_TIMEOUT, thirty_seconds);
    assert_eq!(READ_TIMEOUT, thirty_seconds);
}
/// Multipart is selected at and above `MULTIPART_PUT_THRESHOLD`, and not
/// one byte below it.
#[test]
fn should_use_multipart_pins_threshold_boundary() {
    use super::super::multipart::MULTIPART_PUT_THRESHOLD;

    let threshold = MULTIPART_PUT_THRESHOLD;

    // Just under the threshold stays on the single-request path.
    assert!(!should_use_multipart(threshold - 1));

    // At, just over, and far over (6 GiB) the threshold.
    assert!(should_use_multipart(threshold));
    assert!(should_use_multipart(threshold + 1));
    assert!(should_use_multipart(6 * (1 << 30)));
}
/// After the upload timeout config takes defaults from a base config that
/// has a 99 s read timeout, its read timeout must still read back as
/// `None`.
#[test]
fn put_body_upload_override_disables_read_timeout() {
    let base_read_timeout = Duration::from_secs(99);
    let base = TimeoutConfig::builder()
        .read_timeout(base_read_timeout)
        .build();

    let mut upload_cfg = upload_timeout_config();
    let effective_read_timeout = upload_cfg.take_defaults_from(&base).read_timeout();

    assert_eq!(
        effective_read_timeout,
        None,
        "upload override must disable read_timeout, not just leave it Unset",
    );
}
/// Builds an S3 client for region `us-east-1` whose endpoint is
/// `http://127.0.0.1:1/`.
fn test_client() -> aws_sdk_s3::Client {
    let builder = aws_sdk_s3::Config::builder()
        .behavior_version(BehaviorVersion::latest())
        .region(Region::new("us-east-1"))
        .endpoint_url("http://127.0.0.1:1/");
    aws_sdk_s3::Client::from_conf(builder.build())
}
/// Builds a `MultipartUploadGuard` over [`test_client`] with fixed
/// placeholder identifiers (`bkt`, `k`, `uid`).
fn make_guard() -> MultipartUploadGuard {
    let client = test_client();
    MultipartUploadGuard::new(
        client,
        String::from("bkt"),
        String::from("k"),
        String::from("uid"),
    )
}
/// The guard's accessors return the key and upload id it was constructed
/// with. The guard is disarmed before it goes out of scope.
#[test]
fn multipart_upload_guard_exposes_constructor_fields() {
    let mut guard = make_guard();

    let key = guard.key();
    assert_eq!(key, "k");
    let upload_id = guard.upload_id();
    assert_eq!(upload_id, "uid");

    guard.disarm();
}
/// Dropping a disarmed guard from a plain (non-async) test must complete
/// without panicking.
#[test]
fn multipart_upload_guard_disarmed_drop_outside_runtime_is_silent() {
    let mut disarmed = make_guard();
    disarmed.disarm();
    drop(disarmed);
}
/// Dropping a guard that was never disarmed, from a plain (non-async)
/// test, must not panic.
#[test]
fn multipart_upload_guard_armed_drop_outside_runtime_does_not_panic() {
    drop(make_guard());
}
/// Drops a never-disarmed guard inside a Tokio test, then yields a few
/// times so anything scheduled by the drop gets a chance to run. The test
/// passes as long as nothing panics.
#[tokio::test]
async fn multipart_upload_guard_armed_drop_inside_runtime_spawns_abort_task() {
    drop(make_guard());

    for _ in 0..4 {
        tokio::task::yield_now().await;
    }
}
/// Builds a path-style S3 client wired to a request-capturing HTTP client
/// and returns it with the receiver that exposes the captured request.
/// The credentials are fixed placeholder values.
fn capture_client() -> (
    aws_sdk_s3::Client,
    aws_smithy_http_client::test_util::CaptureRequestReceiver,
) {
    use aws_sdk_s3::config::Credentials;

    let (http_client, receiver) = aws_smithy_http_client::test_util::capture_request(None);
    let credentials = Credentials::new(
        "AKIAIOSFODNN7EXAMPLE",
        "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
        None,
        None,
        "test",
    );

    let conf = aws_sdk_s3::Config::builder()
        .behavior_version(BehaviorVersion::latest())
        .region(Region::new("us-east-1"))
        .credentials_provider(credentials)
        .http_client(http_client)
        .force_path_style(true)
        .build();

    let client = aws_sdk_s3::Client::from_conf(conf);
    (client, receiver)
}
/// Dropping a never-disarmed guard must result in a captured `DELETE`
/// request whose URI contains the guard's key and an `uploadId=` query
/// parameter carrying its upload id.
#[tokio::test]
async fn multipart_upload_guard_drop_issues_abort_multipart_upload() {
    let (client, rx) = capture_client();

    drop(MultipartUploadGuard::new(
        client,
        String::from("test-bucket"),
        String::from("test/key.pack"),
        String::from("test-upload-id-abc123"),
    ));

    // Yield repeatedly so work scheduled by the drop can run before the
    // captured request is inspected.
    for _ in 0..16 {
        tokio::task::yield_now().await;
    }

    let request = rx.expect_request();

    let method = request.method();
    assert_eq!(
        method,
        "DELETE",
        "AbortMultipartUpload must be DELETE; got {}",
        method,
    );

    let uri = request.uri();
    assert!(
        uri.contains("test/key.pack"),
        "captured URI must address the guard's key; got {uri}",
    );
    assert!(
        uri.contains("uploadId=test-upload-id-abc123"),
        "captured URI must carry the guard's upload-id in the \
         query string; got {uri}",
    );
}
}