pub mod error;
use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc, time};
use aws_config::{
environment::EnvironmentVariableRegionProvider,
meta::{credentials::CredentialsProviderChain, region::RegionProviderChain},
profile::ProfileFileRegionProvider,
};
use aws_sdk_s3::{
primitives::ByteStream,
types::{CompletedMultipartUpload, CompletedPart},
};
use aws_types::region::Region;
use bytes::BytesMut;
use error::ClientError;
use tokio::{
io::AsyncReadExt,
sync::{Semaphore, mpsc},
};
use crate::{
secret::{Secret, SecretCorruptionError},
utils::to_human_readable_size,
};
// Chunk sizing for multipart uploads (all values in bytes):
// 4 << 23 = 32 MiB minimum chunk, 12 << 23 = 96 MiB recommended maximum.
const MIN_CHUNK_SIZE: u64 = 4 << 23;
const MAX_RECOMMENDED_CHUNK_SIZE: u64 = 12 << 23;
// S3 caps a multipart upload at 10,000 parts; keep a small safety margin.
const MAX_CHUNKS: u64 = 10_000 - 10;
// Object-size thresholds (8 GiB / 96 GiB) between which `compute_chunk_size`
// interpolates the chunk size linearly.
const LOWER_LIMIT: u64 = 8 << 30; const UPPER_LIMIT: u64 = 96 << 30;
/// A single S3 object that package data can be streamed from.
#[derive(Debug, Clone)]
pub struct Source {
    // Bucket + client pair identifying where the object lives.
    pub(crate) location: Location,
    // Key of the object within the bucket.
    pub(crate) object: String,
}
impl Source {
pub fn new(location: Location, object: String) -> Self {
Self { location, object }
}
}
impl crate::package::source::IntoAsyncRead for aws_sdk_s3::operation::get_object::GetObjectOutput {
    /// Converts the S3 response body into an async reader over the object's bytes.
    fn into_async_reader(self) -> impl crate::package::source::PackageStreamReader {
        self.body.into_async_read()
    }
}
/// Lets `?` lift an [`error::OpenError`] into the `Or` side of the
/// package-source `Either` error type, regardless of the other variant `A`.
impl<A> From<error::OpenError> for crate::package::source::Either<A, error::OpenError> {
    fn from(value: error::OpenError) -> Self {
        Self::Or(value)
    }
}
impl crate::package::source::PackageStream for Source {
    type Error = error::OpenError;
    type FileStream = aws_sdk_s3::operation::get_object::GetObjectOutput;

    /// Downloads the tail of the object into memory and returns it as a
    /// seekable reader together with the absolute offset of its first byte.
    ///
    /// At most 1 MiB is pulled (the whole object if smaller); presumably the
    /// archive's central directory lives within this tail — confirm against
    /// the zip-parsing code that consumes the reader.
    async fn open_central_directory_stream(
        &self,
    ) -> Result<(impl std::io::Read + std::io::Seek, u64), Self::Error> {
        let object_size = self.size().await?;
        let bytes_to_pull = std::cmp::min(1 << 20, object_size);
        // `size()` guarantees `object_size >= 1`, so this cannot underflow.
        let buffer_offset = object_size - bytes_to_pull;
        let data = self
            .location
            .s3_client()
            .await?
            .get_object()
            .bucket(&self.location.bucket)
            .key(&self.object)
            // HTTP byte ranges are inclusive on both ends.
            .range(format!("bytes={}-{}", buffer_offset, object_size - 1))
            .send()
            .await?
            .body
            .collect()
            .await?
            .into_bytes();
        Ok((std::io::Cursor::new(data), buffer_offset))
    }

    /// Opens a stream over one archived file's raw contents.
    ///
    /// Issues two ranged requests: a small window at `offset` to measure the
    /// local file header, then a second request for exactly `total_size`
    /// bytes following that header.
    async fn open_file_header_stream(
        &self,
        offset: u64,
        total_size: u64,
    ) -> Result<
        Self::FileStream,
        crate::package::source::Either<crate::zip::error::HeaderParseError, Self::Error>,
    > {
        // NOTE(review): this inclusive range fetches 1025 bytes
        // (offset..=offset+1024) — presumably enough for any local header;
        // confirm against the maximum header size the parser accepts.
        let object = self
            .location
            .s3_client()
            .await
            .map_err(Self::Error::from)?
            .get_object()
            .bucket(&self.location.bucket)
            .key(&self.object)
            .range(format!("bytes={}-{}", offset, offset + 1024))
            .send()
            .await
            .map_err(Self::Error::from)?;
        let mut reader = object.body.into_async_read();
        // Parse just enough of the window to learn the header's length.
        let header_size = Self::local_header_size(&mut reader).await?;
        // Second request: skip the header and stream the file data itself.
        let object = self
            .location
            .s3_client()
            .await
            .map_err(Self::Error::from)?
            .get_object()
            .bucket(&self.location.bucket)
            .key(&self.object)
            .range(format!(
                "bytes={}-{}",
                offset + header_size,
                offset + header_size + total_size - 1,
            ))
            .send()
            .await
            .map_err(Self::Error::from)?;
        Ok(object)
    }

    /// Opens a stream over the entire, unmodified object.
    async fn open_raw(&self) -> Result<Self::FileStream, Self::Error> {
        Ok(self
            .location
            .s3_client()
            .await?
            .get_object()
            .bucket(&self.location.bucket)
            .key(&self.object)
            .send()
            .await?)
    }

    /// Returns the object's size in bytes via a `HeadObject` request.
    ///
    /// # Errors
    /// Fails when the request errors, the response carries no content length,
    /// or the object is empty (`ObjectTooSmall`).
    async fn size(&self) -> Result<u64, Self::Error> {
        let metadata = self
            .location
            .s3_client()
            .await?
            .head_object()
            .bucket(&self.location.bucket)
            .key(&self.object)
            .send()
            .await
            .map_err(crate::remote::s3::error::get::Error)?;
        let object_size = metadata
            .content_length()
            .ok_or(error::OpenError::MissingSizeInfo)?;
        if object_size < 1 {
            return Err(error::OpenError::ObjectTooSmall);
        }
        Ok(object_size as u64)
    }

    /// The object key doubles as the stream's display name.
    fn name(&self) -> String {
        self.object.clone()
    }
}
/// S3 credential material, wrapped in [`Secret`]s — presumably so the raw
/// values are not exposed via `Debug`; confirm against `Secret`'s docs.
#[derive(Debug)]
pub(crate) struct Credentials {
    access_key: Secret,
    secret_key: Secret,
    // Present only when temporary STS credentials are in use.
    session_token: Option<Secret>,
    // `None` for static credentials that never expire.
    expiration_time: Option<time::SystemTime>,
}
impl Credentials {
    /// Builds a credential set; `expiration_time` stays `None` for static keys.
    pub(crate) fn new(
        access_key: impl Into<Secret>,
        secret_key: impl Into<Secret>,
        session_token: Option<impl Into<Secret>>,
        expiration_time: Option<time::SystemTime>,
    ) -> Self {
        let session_token = session_token.map(Into::into);
        Self {
            access_key: access_key.into(),
            secret_key: secret_key.into(),
            session_token,
            expiration_time,
        }
    }

    /// Exposes the access key ID as a plain string.
    fn reveal_access_key(&self) -> Result<String, SecretCorruptionError> {
        self.access_key.reveal("S3 access key")
    }

    /// Exposes the secret access key as a plain string.
    fn reveal_secret_key(&self) -> Result<String, SecretCorruptionError> {
        self.secret_key.reveal("S3 secret key")
    }

    /// Exposes the session token as a plain string, if one is present.
    fn reveal_session_token(&self) -> Result<Option<String>, SecretCorruptionError> {
        match &self.session_token {
            Some(token) => token.reveal("S3 STS session token").map(Some),
            None => Ok(None),
        }
    }
}
/// The S3 operation (and its scope) a caller wants credentials for.
#[cfg(feature = "auth")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum AccessPermission {
    // Read access to a single object, identified by its key.
    GetObject(String),
    // Listing access; the meaning of the `u32` is not evident from this file —
    // presumably an identifier understood by the credential issuer; confirm.
    ListObjects(u32),
    // Write access; `u32` as above — confirm against the issuing service.
    PutObject(u32),
}
/// Everything needed to construct an authenticated S3 client.
#[cfg(feature = "auth")]
#[derive(Debug)]
pub struct ConnectionInfo {
    // Endpoint URL of the object store.
    pub(crate) endpoint: String,
    // Bucket the issued credentials grant access to.
    pub(crate) bucket: String,
    // Temporary credentials issued for the requested permission.
    pub(crate) credentials: Credentials,
}
/// A request for connection info, answered over the enclosed oneshot channel.
#[cfg(feature = "auth")]
#[derive(Debug)]
pub struct ConnectionInfoRequest {
    // What access the requester needs credentials for.
    pub permission: AccessPermission,
    // Channel on which the responder delivers the result.
    pub oneshot_sender: tokio::sync::oneshot::Sender<Result<ConnectionInfo, error::ClientError>>,
}
/// Asks the credential provider (listening on `info_request_sender`) for
/// connection info granting `permission`, and waits for its answer.
///
/// # Panics
/// Panics if either side of the request/response channel pair is closed;
/// both are expected to live for the duration of the program.
#[cfg(feature = "auth")]
async fn get_connection_info(
    permission: AccessPermission,
    info_request_sender: &mpsc::UnboundedSender<ConnectionInfoRequest>,
) -> Result<ConnectionInfo, error::ClientError> {
    let (oneshot_sender, response) = tokio::sync::oneshot::channel();
    let request = ConnectionInfoRequest {
        permission,
        oneshot_sender,
    };
    info_request_sender
        .send(request)
        .expect("mpsc channel listening for S3 info requests should not be closed");
    response.await.expect("oneshot channel should not be closed")
}
/// Builds an S3 client from the given SDK config, forcing path-style bucket
/// addressing and routing through a proxy when one is configured in the
/// environment.
fn new_aws_s3_client(
    sdk_config: aws_types::SdkConfig,
) -> Result<aws_sdk_s3::Client, proxy::error::ConnectionError> {
    let mut builder = aws_sdk_s3::config::Builder::from(&sdk_config).force_path_style(true);
    match proxy::get_url_from_env() {
        Some(proxy_url) => {
            tracing::trace!(proxy_url = %proxy_url, "Proxy URL found");
            builder = builder.http_client(proxy::build_http_client(&proxy_url)?);
        }
        None => tracing::trace!("No proxy configured"),
    }
    Ok(aws_sdk_s3::Client::from_conf(builder.build()))
}
/// Builder for [`Client`]; all settings are optional and fall back to the
/// AWS SDK's environment/profile resolution when unset.
#[derive(Clone, Debug, Default)]
pub struct ClientBuilder {
    endpoint: Option<String>,
    region: Option<String>,
    profile: Option<String>,
    access_key: Option<Secret>,
    secret_key: Option<Secret>,
    session_token: Option<Secret>,
}
/// Generates a consuming builder setter named `$name` that stores the value
/// converted into `$type` (or clears the field when `None` is passed),
/// documented with `$doc`.
macro_rules! setter {
    ($name:ident, $type:ident, $doc:literal) => {
        #[doc = $doc]
        pub fn $name(mut self, $name: Option<impl Into<$type>>) -> Self {
            self.$name = $name.map(Into::into);
            self
        }
    };
}
impl ClientBuilder {
    /// Creates a builder with no settings applied.
    pub fn new() -> Self {
        Self::default()
    }
    setter!(endpoint, String, "Sets the URL of the S3 object store.");
    setter!(
        region,
        String,
        "Sets the AWS region of the S3 object store."
    );
    setter!(
        profile,
        String,
        "Sets a 'profile' name to use instead of the 'default' profile
from an S3 configuration file."
    );
    setter!(access_key, Secret, "Sets the S3 access key ID.");
    setter!(secret_key, Secret, "Sets the S3 secret access key.");
    setter!(
        session_token,
        Secret,
        "Sets the S3 session token (needed for the edge-case where temporary
STS credentials are used)."
    );

    /// Resolves region and credentials and constructs the [`Client`].
    ///
    /// Region resolution order: explicit `region` setting, then the
    /// environment, then the profile file, then the hardcoded fallback
    /// region "the-shire".
    ///
    /// # Errors
    /// Fails if a configured secret cannot be revealed or the (possibly
    /// proxied) HTTP client cannot be built.
    pub async fn build(self) -> Result<Client, error::ClientError> {
        let region = self.region.map(Region::new);
        let region_provider = RegionProviderChain::first_try(region)
            .or_else(EnvironmentVariableRegionProvider::new())
            .or_else(ProfileFileRegionProvider::new())
            .or_else("the-shire");
        // Explicit static credentials win when both keys are supplied;
        // otherwise defer to the SDK's default provider chain.
        let credentials_provider = if let (Some(access_key), Some(secret_key)) =
            (self.access_key, self.secret_key)
        {
            let credentials = Credentials::new(access_key, secret_key, self.session_token, None);
            CredentialsProviderChain::first_try(
                "from-local-user",
                aws_sdk_s3::config::Credentials::new(
                    credentials.reveal_access_key()?,
                    credentials.reveal_secret_key()?,
                    credentials.reveal_session_token()?,
                    credentials.expiration_time,
                    "local-user",
                ),
            )
        } else {
            CredentialsProviderChain::default_provider().await
        };
        let mut config_loader = aws_config::from_env();
        // NOTE(review): when a profile is selected, the region/credential
        // providers built above are not attached at all — confirm that
        // profile-based configuration is meant to win wholesale.
        if let Some(profile) = self.profile {
            config_loader = config_loader.profile_name(profile);
        } else {
            config_loader = config_loader
                .region(region_provider)
                .credentials_provider(credentials_provider);
        }
        if let Some(endpoint) = self.endpoint {
            config_loader = config_loader.endpoint_url(endpoint);
        }
        let sdk_config = config_loader.load().await;
        Ok(Client {
            endpoint: sdk_config.endpoint_url().unwrap_or_default().into(),
            inner: new_aws_s3_client(sdk_config)?,
        })
    }
}
/// An S3 client with fixed (non-refreshing) credentials.
#[derive(Debug, Clone)]
pub struct Client {
    // Endpoint URL the SDK config resolved to (may be empty if none was set).
    endpoint: String,
    inner: aws_sdk_s3::Client,
}
impl S3Client for Client {
    /// The static client never needs a credential refresh, so this simply
    /// hands out a clone of the wrapped SDK client.
    fn as_inner(
        &self,
    ) -> Pin<Box<dyn Future<Output = Result<aws_sdk_s3::Client, ClientError>> + Send + '_>> {
        let client = self.inner.clone();
        Box::pin(async move { Ok(client) })
    }

    fn endpoint(&self) -> &String {
        &self.endpoint
    }
}
/// An SDK client paired with the expiry time of the credentials it was
/// built from, so callers can decide when to refresh.
#[cfg(feature = "auth")]
#[derive(Debug)]
struct AwsClientWithExpiry {
    expiration_time: time::SystemTime,
    inner: aws_sdk_s3::Client,
}
#[cfg(feature = "auth")]
impl AwsClientWithExpiry {
    /// Builds an SDK client from temporary credentials against `endpoint`.
    ///
    /// The region is a placeholder ("the-shire"); presumably the explicit
    /// endpoint URL makes the region irrelevant — confirm for the target
    /// object store.
    ///
    /// # Panics
    /// Panics if `credentials.expiration_time` is `None`; the auth path is
    /// expected to always issue expiring STS credentials.
    ///
    /// # Errors
    /// Fails if a secret cannot be revealed or the HTTP client cannot be
    /// built.
    async fn new(
        credentials: &Credentials,
        endpoint: impl Into<String>,
    ) -> Result<Self, error::ClientError> {
        let credentials_provider = CredentialsProviderChain::first_try(
            "no-local-environment",
            aws_sdk_s3::config::Credentials::new(
                credentials.reveal_access_key()?,
                credentials.reveal_secret_key()?,
                credentials.reveal_session_token()?,
                credentials.expiration_time,
                "no-local-environment",
            ),
        );
        let sdk_config = aws_config::from_env()
            .region("the-shire")
            .endpoint_url(endpoint)
            .credentials_provider(credentials_provider)
            .load()
            .await;
        Ok(Self {
            expiration_time: credentials
                .expiration_time
                .expect("Credentials must always have an expiration time"),
            inner: new_aws_s3_client(sdk_config)?,
        })
    }
}
/// An S3 client that refreshes its temporary credentials on demand by
/// requesting new connection info over a channel.
#[cfg(feature = "auth")]
#[derive(Debug)]
pub(crate) struct ClientAuth {
    endpoint: String,
    // Guarded so the client can be swapped atomically during a refresh.
    inner: tokio::sync::RwLock<AwsClientWithExpiry>,
    // Permission to re-request when credentials are about to expire.
    permission: AccessPermission,
    info_request_sender: mpsc::UnboundedSender<ConnectionInfoRequest>,
}
#[cfg(feature = "auth")]
impl ClientAuth {
    /// Wraps a freshly built, expiring SDK client together with everything
    /// needed to refresh it later.
    pub(crate) async fn new(
        credentials: &Credentials,
        endpoint: String,
        permission: AccessPermission,
        info_request_sender: mpsc::UnboundedSender<ConnectionInfoRequest>,
    ) -> Result<Self, error::ClientError> {
        let client = AwsClientWithExpiry::new(credentials, endpoint.clone()).await?;
        Ok(Self {
            endpoint,
            inner: tokio::sync::RwLock::new(client),
            permission,
            info_request_sender,
        })
    }
}
#[cfg(feature = "auth")]
impl S3Client for ClientAuth {
    /// Returns the wrapped SDK client, transparently refreshing the
    /// temporary credentials when they are close to expiring.
    fn as_inner(
        &self,
    ) -> Pin<Box<dyn Future<Output = Result<aws_sdk_s3::Client, ClientError>> + Send + '_>> {
        Box::pin(async move {
            // Remaining credential lifetime; zero if already expired.
            let remaining_validity_time = self
                .inner
                .read()
                .await
                .expiration_time
                .duration_since(time::SystemTime::now())
                .unwrap_or_default();
            const MAX_ATTEMPTS: u32 = 5;
            const INITIAL_DELAY: time::Duration = time::Duration::from_secs(8);
            // Refresh early enough that a full retry-with-backoff cycle can
            // still complete before the current credentials expire.
            const TRY_REFRESH_BEFORE: time::Duration =
                super::cumulative_backoff_duration(INITIAL_DELAY, MAX_ATTEMPTS);
            if remaining_validity_time < TRY_REFRESH_BEFORE {
                let connection_info =
                    super::retry_with_backoff(MAX_ATTEMPTS, INITIAL_DELAY, || {
                        get_connection_info(self.permission.clone(), &self.info_request_sender)
                    })
                    .await?;
                // NOTE(review): the read lock is released before the write
                // lock is taken, so concurrent callers may each refresh —
                // presumably a harmless duplicate request; confirm.
                let mut write_guard = self.inner.write().await;
                *write_guard =
                    AwsClientWithExpiry::new(&connection_info.credentials, &self.endpoint).await?;
            }
            Ok(self.inner.read().await.inner.clone())
        })
    }

    fn endpoint(&self) -> &String {
        &self.endpoint
    }
}
pub(crate) trait S3Client: Send + Sync + Debug {
fn endpoint(&self) -> &String;
fn as_inner(
&self,
) -> Pin<Box<dyn Future<Output = Result<aws_sdk_s3::Client, ClientError>> + Send + '_>>;
fn put_object<'a>(
&'a self,
object_name: &'a str,
bucket: &'a str,
mut data: mpsc::Receiver<BytesMut>,
) -> Pin<Box<dyn Future<Output = Result<(), error::UploadError>> + Send + 'a>>
where
Self: Send + Sync,
{
Box::pin(async move {
let log_completion = |size| {
tracing::debug!(
"Successfully transferred '{}' to bucket '{}' (size: {})",
object_name,
bucket,
to_human_readable_size(size)
);
};
let mut counter: u64 = 0;
let mut part_number: i32 = 0;
let mut upload_id = String::new();
let semaphore = Arc::new(Semaphore::new(4));
let mut join_handles = Vec::new();
while let Some(chunk) = data.recv().await {
if part_number == 0 {
let chunk_len = chunk.len();
if chunk_len < chunk.capacity() {
self.as_inner()
.await?
.put_object()
.bucket(bucket)
.key(object_name)
.body(ByteStream::from(chunk.freeze()))
.send()
.await
.map_err(error::put::Error::from)?;
log_completion(chunk_len as u64);
return Ok(());
} else {
self.as_inner()
.await?
.create_multipart_upload()
.bucket(bucket)
.key(object_name)
.send()
.await
.map_err(error::put::Error::from)?
.upload_id()
.ok_or(error::put::Error::FetchMultipartId)?
.clone_into(&mut upload_id);
}
}
part_number += 1;
counter += chunk.len() as u64;
let permit = semaphore
.clone()
.acquire_owned()
.await
.map_err(error::put::Error::from);
let upload_id_copy = upload_id.clone();
let object_name_copy = object_name.to_string();
let bucket_copy = bucket.to_string();
let s3_client_copy = self.as_inner().await?;
let join_handle = tokio::task::spawn(async move {
let upload_part_res = s3_client_copy
.upload_part()
.key(object_name_copy)
.bucket(bucket_copy)
.upload_id(upload_id_copy)
.body(ByteStream::from(chunk.freeze()))
.part_number(part_number)
.send()
.await?;
let completed_part = CompletedPart::builder()
.e_tag(
upload_part_res
.e_tag
.ok_or(error::put::Error::FetchEntityTag)?,
)
.part_number(part_number)
.build();
drop(permit);
Ok(completed_part)
})
as tokio::task::JoinHandle<Result<_, error::put::Error>>;
join_handles.push(join_handle);
}
let mut upload_parts = Vec::with_capacity(join_handles.len());
for join_handle in join_handles {
upload_parts.push(join_handle.await??);
}
if upload_parts.is_empty() {
return Err(error::put::Error::EmptyUpload.into());
}
let completed_multipart_upload = CompletedMultipartUpload::builder()
.set_parts(Some(upload_parts))
.build();
self.as_inner()
.await?
.complete_multipart_upload()
.bucket(bucket)
.key(object_name)
.multipart_upload(completed_multipart_upload)
.upload_id(upload_id)
.send()
.await
.map_err(error::put::Error::from)?;
log_completion(counter);
Ok(())
})
}
}
/// A bucket on a specific S3 endpoint, reachable via some [`S3Client`].
#[derive(Debug, Clone)]
pub struct Location {
    bucket: String,
    client: Arc<dyn S3Client>,
}
impl Location {
    /// Wraps a bucket name together with an unauthenticated [`Client`].
    pub fn new(bucket: impl Into<String>, client: Client) -> Self {
        let bucket = bucket.into();
        Self {
            bucket,
            client: Arc::new(client),
        }
    }

    /// Requests connection info granting `permission` and builds a location
    /// around the resulting self-refreshing client and its bucket.
    #[cfg(feature = "auth")]
    pub async fn new_authenticated(
        permission: AccessPermission,
        info_request_sender: mpsc::UnboundedSender<ConnectionInfoRequest>,
    ) -> Result<Self, error::ClientError> {
        let info = get_connection_info(permission.clone(), &info_request_sender).await?;
        let auth_client = ClientAuth::new(
            &info.credentials,
            info.endpoint,
            permission,
            info_request_sender,
        )
        .await?;
        Ok(Self {
            bucket: info.bucket,
            client: Arc::new(auth_client),
        })
    }

    /// Name of the bucket this location points at.
    pub(crate) fn bucket(&self) -> &str {
        self.bucket.as_str()
    }

    /// `<endpoint>/<bucket>`, with trailing slashes on the endpoint removed.
    pub(crate) fn bucket_path(&self) -> String {
        let endpoint = self.client.endpoint().trim_end_matches('/');
        format!("{}/{}", endpoint, self.bucket())
    }

    /// `<endpoint>/<bucket>/<object_name>`.
    pub fn object_path(&self, object_name: &str) -> String {
        format!("{}/{}", self.bucket_path(), object_name)
    }

    /// Obtains the underlying SDK client (may refresh credentials first).
    pub(crate) async fn s3_client(&self) -> Result<aws_sdk_s3::Client, ClientError> {
        self.client.as_inner().await
    }

    /// Streams the received chunks into `object_name` within this bucket.
    pub(crate) async fn put_object(
        &self,
        object_name: &str,
        data: mpsc::Receiver<BytesMut>,
    ) -> Result<(), error::UploadError> {
        let bucket = self.bucket();
        self.client.put_object(object_name, bucket, data).await
    }
}
/// Reads `source` into chunks sized from `source_size` and forwards each
/// filled chunk to `sink`, optionally reporting progress.
///
/// The final (possibly partial) chunk is flushed at EOF; the receiving side
/// can detect "entire stream fit in one chunk" because only a full chunk
/// reaches the configured capacity. Note that an empty source still sends
/// one zero-length chunk.
///
/// # Errors
/// Fails if the chunk size cannot be computed, a read fails, or the
/// receiving side of `sink` has been closed.
pub(crate) async fn read_chunks_from_stream(
    mut source: impl tokio::io::AsyncBufRead + Unpin + Send + Sync,
    source_size: u64,
    sink: mpsc::Sender<BytesMut>,
    progress: Option<impl crate::progress::ProgressDisplay>,
) -> Result<(), error::ReadChunksError> {
    let chunk_size = compute_chunk_size(source_size)?;
    let mut buffer = BytesMut::with_capacity(chunk_size);
    let mut task = progress.map(|p| p.start(source_size));
    loop {
        let n = source.read_buf(&mut buffer).await?;
        // Flush on a full buffer, or flush whatever is left at EOF (n == 0);
        // `replace` hands off the buffer and starts a fresh one.
        if n == 0 || buffer.len() >= chunk_size {
            sink.send(std::mem::replace(
                &mut buffer,
                BytesMut::with_capacity(chunk_size),
            ))
            .await?;
        }
        if n == 0 {
            break;
        }
        if let Some(t) = &mut task {
            t.increment(n as u64);
        }
    }
    Ok(())
}
mod proxy {
use aws_sdk_s3::config::SharedHttpClient;
use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder;
use hyper_proxy::{Custom, Intercept, Proxy, ProxyConnector};
use hyper_rustls::ConfigBuilderExt;
/// Splits a `NO_PROXY`-style value on commas, trimming whitespace and
/// dropping empty entries.
fn parse_no_proxy_entries(no_proxy_value: &str) -> Vec<String> {
    let mut entries = Vec::new();
    for raw_entry in no_proxy_value.split(',') {
        let trimmed = raw_entry.trim();
        if !trimmed.is_empty() {
            entries.push(trimmed.to_owned());
        }
    }
    entries
}
/// Returns `true` when `host` matches any bypass entry: either exactly, or —
/// for entries starting with a dot — as the bare domain or any subdomain.
fn should_bypass(entries: &[impl AsRef<str>], host: &str) -> bool {
    for entry in entries {
        let entry = entry.as_ref();
        let exact_match = entry == host;
        let domain_match = match entry.strip_prefix('.') {
            Some(bare_domain) => host == bare_domain || host.ends_with(entry),
            None => false,
        };
        if exact_match || domain_match {
            return true;
        }
    }
    false
}
/// Maps bypass entries to a proxy intercept policy: no entries means proxy
/// everything, a `*` entry disables the proxy entirely, and anything else
/// becomes a per-host matcher.
fn build_intercept_from_value(entries: Vec<String>) -> Intercept {
    match entries {
        e if e.is_empty() => Intercept::All,
        e if e.iter().any(|entry| entry == "*") => Intercept::None,
        e => Intercept::Custom(Custom::from(
            move |_scheme: Option<&str>, host: Option<&str>, _port: Option<u16>| -> bool {
                host.is_some_and(|h| !should_bypass(&e, h))
            },
        )),
    }
}
/// Builds the intercept policy from the `NO_PROXY`/`no_proxy` environment
/// variables (uppercase checked first; empty when neither is set).
fn build_intercept() -> Intercept {
    let no_proxy = ["NO_PROXY", "no_proxy"]
        .into_iter()
        .find_map(|name| std::env::var(name).ok())
        .unwrap_or_default();
    build_intercept_from_value(parse_no_proxy_entries(&no_proxy))
}
/// Returns the first proxy URL found in the conventional environment
/// variables, checking uppercase names before lowercase ones.
pub(super) fn get_url_from_env() -> Option<String> {
    let sources = [
        "HTTP_PROXY",
        "HTTPS_PROXY",
        "ALL_PROXY",
        "http_proxy",
        "https_proxy",
        "all_proxy",
    ];
    sources.into_iter().find_map(|name| {
        let val = std::env::var(name).ok()?;
        tracing::trace!(source = %name, proxy_url = %val, "Found proxy URL from source");
        Some(val)
    })
}
pub mod error {
    /// Errors raised while setting up the proxied HTTP client.
    #[derive(Debug)]
    pub enum ConnectionError {
        /// The proxy URL could not be parsed; holds the parser's message.
        UrlParse(String),
        /// Establishing the proxied connection failed at the I/O level.
        Connection(std::io::Error),
    }

    impl std::fmt::Display for ConnectionError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                ConnectionError::UrlParse(message) => f.write_str(message),
                ConnectionError::Connection(e) => write!(f, "{e}"),
            }
        }
    }

    impl std::error::Error for ConnectionError {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            if let ConnectionError::Connection(source) = self {
                Some(source)
            } else {
                None
            }
        }
    }

    impl From<std::io::Error> for ConnectionError {
        fn from(value: std::io::Error) -> Self {
            ConnectionError::Connection(value)
        }
    }
}
/// Builds an HTTPS-capable HTTP client that routes requests through the
/// given proxy, honouring the `NO_PROXY` bypass rules.
///
/// # Errors
/// Fails when the proxy URL cannot be parsed or the proxy connector cannot
/// be constructed.
pub(super) fn build_http_client(
    proxy_url: &str,
) -> Result<SharedHttpClient, error::ConnectionError> {
    tracing::trace!(proxy_url = %proxy_url, "Building HTTP client for proxy URL");
    // TLS is restricted to an explicit allow-list: TLS 1.3 suites plus
    // forward-secret ECDHE suites, verified against the native root store.
    let connector = hyper_rustls::HttpsConnectorBuilder::new()
        .with_tls_config(
            rustls::ClientConfig::builder()
                .with_cipher_suites(&[
                    rustls::cipher_suite::TLS13_AES_256_GCM_SHA384,
                    rustls::cipher_suite::TLS13_AES_128_GCM_SHA256,
                    rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
                    rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
                    rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
                    rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
                    rustls::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
                ])
                .with_safe_default_kx_groups()
                .with_safe_default_protocol_versions()
                .expect("Error with the TLS configuration")
                .with_native_roots()
                .with_no_client_auth(),
        )
        .https_or_http()
        .enable_http1()
        .enable_http2()
        .build();
    let proxy_parsed_url = proxy_url
        .parse()
        .map_err(|e| error::ConnectionError::UrlParse(format!("{e}")))?;
    let proxy =
        ProxyConnector::from_proxy(connector, Proxy::new(build_intercept(), proxy_parsed_url))?;
    Ok(HyperClientBuilder::new().build(proxy))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Whitespace around entries is trimmed and empty entries are dropped.
    #[test]
    fn test_parse_no_proxy_entries_trimming_and_splitting() {
        assert_eq!(
            &parse_no_proxy_entries(" localhost, .example.com ,api.internal "),
            &["localhost", ".example.com", "api.internal"]
        );
    }

    // A leading-dot entry matches both the bare domain and its subdomains.
    #[test]
    fn test_should_bypass_exact_and_suffix() {
        assert!(should_bypass(&["localhost", ".example.com"], "localhost"));
        let entries = [".example.com"];
        assert!(should_bypass(&entries, "api.example.com"));
        assert!(should_bypass(&entries, "example.com"));
        assert!(!should_bypass(&entries, "example.org"));
    }

    // No entries proxies everything; a "*" entry disables the proxy entirely;
    // anything else produces a custom per-host matcher.
    #[test]
    fn test_build_intercept_from_value_variants() {
        assert!(matches!(
            build_intercept_from_value(Vec::new()),
            Intercept::All
        ));
        assert!(matches!(
            build_intercept_from_value(vec![".example.com".to_string(), "*".to_string()]),
            Intercept::None
        ));
        assert!(matches!(
            build_intercept_from_value(vec![".example.com".to_string()]),
            Intercept::Custom(_)
        ));
    }
}
}
/// Picks an upload chunk size for an object of `data_size` bytes.
///
/// Objects up to `LOWER_LIMIT` use `MIN_CHUNK_SIZE`; between the limits the
/// size grows linearly up to `MAX_RECOMMENDED_CHUNK_SIZE`; beyond
/// `UPPER_LIMIT` the size is whatever keeps the part count within
/// `MAX_CHUNKS`.
///
/// # Errors
/// Fails only if the resulting size does not fit into `usize`.
pub(crate) fn compute_chunk_size(data_size: u64) -> Result<usize, error::ReadChunksError> {
    // Compile-time sanity checks on the sizing constants.
    const {
        assert!(LOWER_LIMIT < UPPER_LIMIT);
        assert!(MAX_CHUNKS * MIN_CHUNK_SIZE > LOWER_LIMIT);
        assert!(MAX_CHUNKS * MAX_RECOMMENDED_CHUNK_SIZE > UPPER_LIMIT);
    }
    let chunk_size = match data_size {
        size if size <= LOWER_LIMIT => MIN_CHUNK_SIZE,
        size if size <= UPPER_LIMIT => {
            // Linear interpolation between the two recommended chunk sizes.
            let slope = (MAX_RECOMMENDED_CHUNK_SIZE - MIN_CHUNK_SIZE) as f64
                / (UPPER_LIMIT - LOWER_LIMIT) as f64;
            MIN_CHUNK_SIZE + (slope * (size - LOWER_LIMIT) as f64).ceil() as u64
        }
        size => {
            // Huge objects: spread the data over at most MAX_CHUNKS parts.
            let per_chunk = (size as f64 / MAX_CHUNKS as f64).ceil() as u64;
            per_chunk.max(MAX_RECOMMENDED_CHUNK_SIZE)
        }
    };
    Ok(chunk_size.try_into()?)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Exercises all three sizing regimes: flat minimum, linear interpolation,
    // and the part-count-capped region for very large objects.
    #[test]
    fn test_compute_chunk_size() {
        assert_eq!(compute_chunk_size(0).unwrap(), MIN_CHUNK_SIZE as usize);
        assert_eq!(compute_chunk_size(1).unwrap(), MIN_CHUNK_SIZE as usize);
        assert_eq!(
            compute_chunk_size(LOWER_LIMIT).unwrap(),
            MIN_CHUNK_SIZE as usize
        );
        // Midway between the limits the size must lie strictly between the
        // two recommended bounds.
        let chunk_size =
            compute_chunk_size(((LOWER_LIMIT + UPPER_LIMIT) as f64 / 2.0).ceil() as u64).unwrap();
        assert!(chunk_size > MIN_CHUNK_SIZE as usize);
        assert!(chunk_size < MAX_RECOMMENDED_CHUNK_SIZE as usize);
        assert_eq!(
            compute_chunk_size(UPPER_LIMIT).unwrap(),
            MAX_RECOMMENDED_CHUNK_SIZE as usize
        );
        assert_eq!(
            compute_chunk_size(UPPER_LIMIT + 10_000_000).unwrap(),
            MAX_RECOMMENDED_CHUNK_SIZE as usize
        );
        // 2 TiB exceeds the recommended maximum, so the chunk grows to keep
        // the part count under the cap.
        let data_size_2tb = 2 << 40;
        assert!(compute_chunk_size(data_size_2tb).unwrap() > MAX_RECOMMENDED_CHUNK_SIZE as usize);
        assert_eq!(compute_chunk_size(data_size_2tb).unwrap(), 220122449);
        // Even with some packaging overhead the part count stays under 10,000.
        let packaged_data_size = (data_size_2tb as f64 * 1.001).ceil() as u64;
        let chunk_size = compute_chunk_size(data_size_2tb).unwrap();
        assert!(packaged_data_size < chunk_size as u64 * 10_000);
    }

    // Paths must be endpoint/bucket[/object], with a trailing slash on the
    // endpoint normalized away.
    #[tokio::test]
    async fn test_bucket_path() {
        let bucket = "test_bucket".to_string();
        let endpoint = "https://test.minio.org".to_string();
        let object_name = "file.test".to_string();
        let expected_bucket_path = format!("{endpoint}/{bucket}");
        let expected_object_path = format!("{endpoint}/{bucket}/{object_name}");
        let location = Location::new(
            &bucket,
            ClientBuilder::new()
                .endpoint(Some(&endpoint))
                .build()
                .await
                .unwrap(),
        );
        assert_eq!(location.bucket_path(), expected_bucket_path);
        assert_eq!(location.object_path(&object_name), expected_object_path);
        let location = Location::new(
            bucket,
            ClientBuilder::new()
                .endpoint(Some("https://test.minio.org/".to_string()))
                .build()
                .await
                .unwrap(),
        );
        assert_eq!(location.bucket_path(), expected_bucket_path);
        assert_eq!(location.object_path(&object_name), expected_object_path);
    }
}