use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_sdk_s3::error::SdkError;
use aws_sdk_s3::primitives::ByteStream;
use bytes::Bytes;
use log::warn;
use super::Transport;
/// Endpoint host used by `TransportS3::new` when the caller does not pass one.
const DEFAULT_HOST: &str = "s3.amazonaws.com";
/// Starting delay for `retry!`; being zero, the first retry is issued without waiting.
const MIN_BACKOFF: Duration = Duration::from_secs(0);
/// Upper bound on the delay `retry!` waits between attempts.
const MAX_BACKOFF: Duration = Duration::from_millis(5_000);
/// Connection and bucket settings for `TransportS3`.
///
/// Plain configuration data, so it is cheap to clone and compare.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct S3Parameters {
    /// Region to sign requests for; `TransportS3::new` falls back to `ca-central-1` when `None`.
    pub aws_region: Option<String>,
    /// Bucket that every object operation targets.
    pub s3_bucket: String,
    /// Use `https` (default port 443) instead of `http` (default port 80).
    pub use_ssl: bool,
    /// Verify the server's TLS certificate; when false, any certificate is accepted.
    pub verify: bool,
    /// NOTE(review): not read by the S3 transport code in this file — confirm its consumer.
    pub boto_defaults: bool,
    /// S3-compatible-server mode: path-style addressing and checksums only when required.
    /// (The field name's spelling is part of the public API and is kept as-is.)
    pub compatability: bool,
}
impl Default for S3Parameters {
fn default() -> Self {
Self {
aws_region: None,
s3_bucket: "al-storage".to_string(),
use_ssl: true,
verify: true,
boto_defaults: false,
compatability: true,
}
}
}
/// S3-backed `Transport`: every object lives in `parameters.s3_bucket` on `host:port`.
pub struct TransportS3 {
    /// SDK client configured for the endpoint, region and credentials given to `new`.
    client: aws_sdk_s3::Client,
    /// Endpoint host, also shown by the `Debug` rendering.
    host: String,
    /// Endpoint port, also shown by the `Debug` rendering.
    port: u16,
    /// Connection and bucket settings; `s3_bucket` is the bucket every call targets.
    parameters: S3Parameters,
    /// Prefix shown after the bucket in the `Debug` rendering.
    base: String,
    /// Limit handed to `retry!` for each request; `None` means retry indefinitely.
    retry_limit: Option<usize>,
    /// When true, `new` does not create a missing bucket; reported by `read_only()`.
    read_only: bool,
    /// Access key supplied at construction; retained but not read by the code here.
    _accesskey: Option<String>,
}
impl TransportS3 {
pub async fn new(base: String, host: Option<String>, port: Option<u16>, accesskey: Option<String>, secretkey: Option<String>, connection_attempts: Option<usize>, parameters: S3Parameters, read_only: bool) -> Result<Self> {
let host = host.unwrap_or_else(|| DEFAULT_HOST.to_owned());
let port = match port {
Some(port) => port,
None => if parameters.use_ssl { 443 } else { 80 }
};
let scheme = if parameters.use_ssl { "https" } else { "http" };
let endpoint_url = format!("{scheme}://{host}:{port}");
let mut loader = aws_config::defaults(BehaviorVersion::v2026_01_12());
if let Some(region) = parameters.aws_region.clone() {
loader = loader.region(aws_types::region::Region::new(region));
} else {
loader = loader.region(aws_types::region::Region::from_static("ca-central-1"))
}
loader = loader.endpoint_url(endpoint_url);
if let Some(key) = &accesskey {
std::env::set_var("AWS_ACCESS_KEY_ID", key);
}
if let Some(secret) = secretkey {
std::env::set_var("AWS_SECRET_ACCESS_KEY", secret);
}
loader = loader.http_client({
use legacy_hyper_rustls as hyper_rustls;
use legacy_rustls as rustls;
let https_connector = if parameters.verify {
hyper_rustls::HttpsConnectorBuilder::new()
.with_native_roots()
.https_or_http()
.enable_http1()
.enable_http2()
.build()
} else {
let root_store = rustls::RootCertStore::empty();
let mut tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store.clone())
.with_no_client_auth();
tls_config
.dangerous()
.set_certificate_verifier(Arc::new(verifier::NoCertificateVerification::new()));
hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http1()
.enable_http2()
.build()
};
aws_smithy_http_client::hyper_014::HyperClientBuilder::new()
.build(https_connector)
});
let sdk_config = loader.load().await;
let s3_config = if parameters.compatability {
aws_sdk_s3::config::Builder::from(&sdk_config)
.force_path_style(true)
.request_checksum_calculation(aws_sdk_s3::config::RequestChecksumCalculation::WhenRequired)
.response_checksum_validation(aws_sdk_s3::config::ResponseChecksumValidation::WhenRequired)
.build()
} else {
aws_sdk_s3::config::Builder::from(&sdk_config).build()
};
let client = aws_sdk_s3::Client::from_conf(s3_config);
let head_result = retry!(connection_attempts, {
client.head_bucket().bucket(¶meters.s3_bucket).send().await
});
if let Err(err) = head_result {
let err = err.downcast::<SdkError<aws_sdk_s3::operation::head_bucket::HeadBucketError>>()?;
let err = err.into_service_error();
if err.is_not_found() {
if !read_only{
let create_result = retry!(connection_attempts, {
client.create_bucket().bucket(¶meters.s3_bucket).send().await
});
if let Err(err) = create_result {
let err = err.downcast::<SdkError<aws_sdk_s3::operation::create_bucket::CreateBucketError>>()?;
let x = err.into_service_error();
if !x.is_bucket_already_exists() && !x.is_bucket_already_owned_by_you() {
return Err(x.into())
}
}
}
} else {
return Err(anyhow::Error::new(err).context("head error"))
}
}
Ok(Self {
base,
parameters,
_accesskey: accesskey,
retry_limit: connection_attempts,
client,
host,
port,
read_only
})
}
fn normalize(&self, path: &str) -> Result<String> {
match Path::new(path).file_name() {
Some(path) => Ok(path.to_string_lossy().to_string()),
None => Err(anyhow::anyhow!("Could not normalize path to file name: {path}")),
}
}
}
impl std::fmt::Debug for TransportS3 {
    /// Renders the transport as `s3://host:port/bucket`, followed by `/base` when a base is
    /// set. A single `/` separator is inserted only if `base` does not already start with one.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "s3://{}:{}/{}", self.host, self.port, self.parameters.s3_bucket)?;
        if self.base.is_empty() {
            return Ok(());
        }
        let separator = if self.base.starts_with('/') { "" } else { "/" };
        write!(f, "{separator}{}", self.base)
    }
}
#[async_trait]
impl Transport for TransportS3 {
    /// Stores `body` under the key derived from `name`, retrying transient failures.
    async fn put(&self, name: &str, body: &Bytes) -> Result<()> {
        let label = self.normalize(name)?;
        // Checked conversion: an oversized length is reported rather than silently wrapped.
        let length = i64::try_from(body.len())?;
        retry!(ignore_result, self.retry_limit, {
            self.client
                .put_object()
                .content_type("application/octet-stream")
                .content_length(length)
                .bucket(&self.parameters.s3_bucket)
                .key(label.clone())
                .body(body.clone().into())
                .send()
                .await
        })
    }

    /// Uploads the file at `path` under the key derived from `name`.
    /// The file is reopened for every attempt.
    async fn upload(&self, path: &Path, name: &str) -> Result<()> {
        let label = self.normalize(name)?;
        retry!(ignore_result, self.retry_limit, {
            self.client
                .put_object()
                .content_type("application/octet-stream")
                .bucket(&self.parameters.s3_bucket)
                .key(label.clone())
                .body(ByteStream::from_path(path).await?)
                .send()
                .await
        })
    }

    /// Downloads the whole object, returning `Ok(None)` when the key does not exist.
    async fn get(&self, name: &str) -> Result<Option<Vec<u8>>> {
        let label = self.normalize(name)?;
        fn is_not_found(err: &SdkError<aws_sdk_s3::operation::get_object::GetObjectError>) -> bool {
            err.as_service_error().is_some_and(|e| e.is_no_such_key())
        }
        retry!(self.retry_limit, {
            let response = self.client
                .get_object()
                .bucket(&self.parameters.s3_bucket)
                .key(label.clone())
                .send()
                .await;
            match response {
                Ok(response) => {
                    let bytes = response.body.collect().await?;
                    Ok(Some(bytes.to_vec()))
                }
                Err(err) if is_not_found(&err) => Ok(None),
                Err(err) => Err(err),
            }
        })
    }

    /// Reports whether an object exists for `name`, using a HEAD request.
    async fn exists(&self, name: &str) -> Result<bool> {
        let label = self.normalize(name)?;
        fn is_not_found(err: &SdkError<aws_sdk_s3::operation::head_object::HeadObjectError>) -> bool {
            err.as_service_error().is_some_and(|e| e.is_not_found())
        }
        retry!(self.retry_limit, {
            let response = self.client
                .head_object()
                .bucket(&self.parameters.s3_bucket)
                .key(label.clone())
                .send()
                .await;
            match response {
                Ok(_) => Ok(true),
                Err(err) if is_not_found(&err) => Ok(false),
                Err(err) => Err(err),
            }
        })
    }

    /// Starts streaming the object and returns its size plus a channel of body chunks.
    ///
    /// The request itself is not retried. Chunks are forwarded by a spawned task, which
    /// stops as soon as the receiver is dropped.
    async fn stream(&self, name: &str) -> Result<(u64, tokio::sync::mpsc::Receiver<Result<Bytes, std::io::Error>>)> {
        let label = self.normalize(name)?;
        let mut response = self.client
            .get_object()
            .bucket(&self.parameters.s3_bucket)
            .key(label)
            .send()
            .await?;
        let Some(length) = response.content_length() else {
            anyhow::bail!("S3 did not return blob size");
        };
        // A negative size is reported instead of being wrapped into a huge u64 by `as`.
        let length = u64::try_from(length)
            .map_err(|_| anyhow::anyhow!("S3 returned an invalid blob size: {length}"))?;
        let (send, recv) = tokio::sync::mpsc::channel(64);
        tokio::spawn(async move {
            while let Some(chunk) = response.body.next().await {
                let item = chunk.map_err(std::io::Error::other);
                // A send error means the consumer hung up; stop instead of downloading the
                // rest of the object only to discard it.
                if send.send(item).await.is_err() {
                    break;
                }
            }
        });
        Ok((length, recv))
    }

    /// Deletes the object for `name`, retrying transient failures.
    async fn delete(&self, name: &str) -> Result<()> {
        let label = self.normalize(name)?;
        retry!(ignore_result, self.retry_limit, {
            self.client
                .delete_object()
                .bucket(&self.parameters.s3_bucket)
                .key(label.clone())
                .send()
                .await
        })
    }

    /// Whether this transport was constructed read-only.
    fn read_only(&self) -> bool {
        self.read_only
    }
}
mod verifier {
use legacy_rustls::client::{ServerCertVerified, ServerCertVerifier};
#[derive(Debug)]
pub struct NoCertificateVerification { }
impl NoCertificateVerification {
pub fn new() -> Self {
Self { }
}
}
impl ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &legacy_rustls::Certificate,
_intermediates: &[legacy_rustls::Certificate],
_server_name: &legacy_rustls::ServerName,
_scts: &mut dyn Iterator<Item = &[u8]>,
_ocsp_response: &[u8],
_now: std::time::SystemTime,
) -> Result<legacy_rustls::client::ServerCertVerified, legacy_rustls::Error> {
Ok(ServerCertVerified::assertion())
}
}
}
/// Evaluates `$body`, an expression producing `Result<T, SdkError<..>>`, and re-runs it when
/// the failure looks transient: a timeout, a dispatch failure, or an unreadable response.
///
/// * `$connection_attempts` is an `Option` of an integer retry limit, and `None` retries
///   forever. With `Some(limit)`, `$body` runs at most `limit + 1` times, after which the
///   macro yields `crate::errors::ConnectionError`.
/// * Any other SDK error is converted into `anyhow::Error` and returned immediately.
/// * The `ignore_result` form discards the success value and yields `Result<(), anyhow::Error>`.
///
/// Must be used in an `async` context. The first retry waits `MIN_BACKOFF`. Later delays are
/// at least 100 ms and double each time, capped at `MAX_BACKOFF`. Both constants are resolved
/// at the call site.
macro_rules! retry {
    (ignore_result, $connection_attempts: expr, $body: expr) => {{
        retry!($connection_attempts, $body).map(|_| ())
    }};
    ($connection_attempts: expr, $body: expr) => {{
        let limit = $connection_attempts;
        let mut backoff = MIN_BACKOFF;
        let mut retries = 0;
        loop {
            // Checked before sleeping so that giving up does not first wait for nothing.
            if let Some(limit) = limit {
                if retries > limit {
                    break Err(anyhow::Error::from(crate::errors::ConnectionError));
                }
            }
            if retries > 0 {
                tokio::time::sleep(backoff).await;
                // Doubling alone never leaves zero (`MIN_BACKOFF` may be `Duration::ZERO`),
                // which made repeated failures spin in a hot loop. The floor makes the delay grow.
                backoff = (backoff * 2)
                    .max(std::time::Duration::from_millis(100))
                    .min(MAX_BACKOFF);
            }
            let ret_val = $body;
            retries += 1;
            match ret_val {
                Ok(value) => {
                    if retries > 1 {
                        log::info!("Reconnected to S3 transport!");
                    }
                    break Ok(value);
                }
                Err(aws_sdk_s3::error::SdkError::TimeoutError(timeout)) => {
                    log::warn!("Connection timeout ({timeout:?}) for S3 transport, retrying...");
                }
                Err(aws_sdk_s3::error::SdkError::DispatchFailure(failure)) => {
                    log::warn!("Dispatch failure ({failure:?}) for S3 transport, retrying...");
                }
                Err(aws_sdk_s3::error::SdkError::ResponseError(_)) => {
                    log::warn!("Corrupted response from S3 transport, retrying...");
                }
                Err(err) => break Err(err.into()),
            }
        }
    }};
}
pub(crate) use retry;