use crate::OneIoError;
use std::fmt;
use std::sync::OnceLock;
/// Guard ensuring the `.env` file is read at most once per process.
static DOTENV_INIT: OnceLock<()> = OnceLock::new();

/// Load variables from a `.env` file into the process environment, at most once.
///
/// Any error returned by `dotenvy::dotenv()` is discarded on purpose: configuration can
/// still come from plain environment variables when no usable `.env` file exists.
fn ensure_dotenv() {
    // `drop` both discards the `Result` and makes the closure return `()`,
    // matching the `OnceLock<()>` payload.
    DOTENV_INIT.get_or_init(|| drop(dotenvy::dotenv()));
}
/// Access-key credentials used to sign S3 requests.
///
/// The `Debug` implementation never prints the secret key or the session token. For the
/// token it only reveals whether one is present, so the struct can be logged safely.
#[derive(Clone)]
pub struct S3Credentials {
    /// Access key ID (`AWS_ACCESS_KEY_ID` when loaded via `from_env`); shown by `Debug`.
    pub access_key: String,
    /// Secret access key (`AWS_SECRET_ACCESS_KEY`); always redacted by `Debug`.
    pub secret_key: String,
    /// Optional session token (`AWS_SESSION_TOKEN`); its value is always redacted by `Debug`.
    pub session_token: Option<String>,
}

impl fmt::Debug for S3Credentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("S3Credentials")
            .field("access_key", &self.access_key)
            .field("secret_key", &"<redacted>")
            // Show `Some("<redacted>")` vs `None` so logs reveal whether a token is
            // configured, without ever exposing the token itself.
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
impl S3Credentials {
    /// Load credentials from environment variables, loading a `.env` file once first.
    ///
    /// `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` are required.
    /// `AWS_SESSION_TOKEN` is optional, and an empty value is treated as unset.
    ///
    /// # Errors
    ///
    /// Returns `OneIoError::NotSupported` when a required variable is missing or its
    /// value is not valid Unicode.
    pub fn from_env() -> Result<Self, OneIoError> {
        ensure_dotenv();
        let access_key = Self::required_var("AWS_ACCESS_KEY_ID")?;
        let secret_key = Self::required_var("AWS_SECRET_ACCESS_KEY")?;
        // A line like `AWS_SESSION_TOKEN=` in a `.env` file yields an empty string; without
        // this filter it would be forwarded downstream as if it were a real token.
        let session_token = std::env::var("AWS_SESSION_TOKEN")
            .ok()
            .filter(|token| !token.is_empty());
        Ok(S3Credentials {
            access_key,
            secret_key,
            session_token,
        })
    }

    /// Read a required environment variable, reporting why it could not be used.
    fn required_var(name: &str) -> Result<String, OneIoError> {
        std::env::var(name).map_err(|err| {
            let reason = match err {
                std::env::VarError::NotPresent => "not set",
                std::env::VarError::NotUnicode(_) => "is not valid unicode",
            };
            OneIoError::NotSupported(format!("{name} {reason}"))
        })
    }
}
#[derive(Clone)]
pub struct S3Config {
pub bucket: String,
pub credentials: S3Credentials,
pub endpoint: String,
pub region: String,
pub ttl: std::time::Duration,
pub multipart_chunk_size: u64,
pub multipart_threshold: u64,
}
impl fmt::Debug for S3Config {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("S3Config")
.field("bucket", &self.bucket)
.field("credentials", &self.credentials)
.field("endpoint", &self.endpoint)
.field("region", &self.region)
.field("ttl", &self.ttl)
.field("multipart_chunk_size", &self.multipart_chunk_size)
.field("multipart_threshold", &self.multipart_threshold)
.finish()
}
}
impl S3Config {
    /// Build a configuration for `bucket` from environment variables, loading a `.env`
    /// file once first.
    ///
    /// Values are trimmed, and empty or whitespace-only values count as unset.
    ///
    /// - credentials: see [`S3Credentials::from_env`];
    /// - region: `AWS_REGION`, then `S3_REGION`, else `us-east-1`;
    /// - endpoint: `AWS_ENDPOINT`, then `S3_ENDPOINT`, else
    ///   `https://s3.{region}.amazonaws.com`; the result is passed through
    ///   `normalize_endpoint`;
    /// - `ONEIO_S3_CHUNK_SIZE`: multipart part size, default 8 MiB, never below 5 MiB;
    /// - `ONEIO_S3_MULTIPART_THRESHOLD`: multipart threshold, default 5 MiB.
    ///
    /// Numeric values that do not parse as `u64` fall back to their defaults. `ttl` is
    /// always one hour.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`S3Credentials::from_env`] when credentials are unusable.
    pub fn from_env(bucket: &str) -> Result<Self, OneIoError> {
        const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024;
        // S3 rejects non-final multipart parts smaller than 5 MiB, so any smaller
        // configured part size (including 0) is raised to this floor.
        const MIN_CHUNK_SIZE: u64 = 5 * 1024 * 1024;
        const DEFAULT_THRESHOLD: u64 = 5 * 1024 * 1024;

        ensure_dotenv();
        let credentials = S3Credentials::from_env()?;
        let region = env_non_empty("AWS_REGION")
            .or_else(|| env_non_empty("S3_REGION"))
            .unwrap_or_else(|| "us-east-1".to_string());
        let endpoint = env_non_empty("AWS_ENDPOINT")
            .or_else(|| env_non_empty("S3_ENDPOINT"))
            .unwrap_or_else(|| format!("https://s3.{region}.amazonaws.com"));
        let endpoint = normalize_endpoint(&endpoint);
        let multipart_chunk_size = env_u64("ONEIO_S3_CHUNK_SIZE")
            .unwrap_or(DEFAULT_CHUNK_SIZE)
            .max(MIN_CHUNK_SIZE);
        let multipart_threshold =
            env_u64("ONEIO_S3_MULTIPART_THRESHOLD").unwrap_or(DEFAULT_THRESHOLD);
        Ok(S3Config {
            bucket: bucket.to_string(),
            credentials,
            endpoint,
            region,
            ttl: std::time::Duration::from_secs(3600),
            multipart_chunk_size,
            multipart_threshold,
        })
    }

    /// Convert the stored credentials into [`rusty_s3::Credentials`], attaching the
    /// session token when one is present.
    pub fn rusty_credentials(&self) -> rusty_s3::Credentials {
        let creds = &self.credentials;
        match &creds.session_token {
            Some(token) => {
                rusty_s3::Credentials::new_with_token(&creds.access_key, &creds.secret_key, token)
            }
            None => rusty_s3::Credentials::new(&creds.access_key, &creds.secret_key),
        }
    }

    /// Build a [`rusty_s3::Bucket`] for this configuration.
    ///
    /// Virtual-host URL style is used only when the endpoint's host is an AWS domain
    /// (`amazonaws.com` or `amazonaws.com.cn`, including subdomains) and the bucket name
    /// contains no dots. Every other case uses path style.
    ///
    /// # Errors
    ///
    /// Returns `OneIoError::NotSupported` if the endpoint does not parse as a URL, or if
    /// `rusty_s3` rejects the bucket configuration.
    pub fn rusty_bucket(&self) -> Result<rusty_s3::Bucket, OneIoError> {
        let endpoint = self
            .endpoint
            .parse()
            .map_err(|e| OneIoError::NotSupported(format!("Invalid S3 endpoint: {e}")))?;
        // Match on the host only, so "amazonaws.com" appearing elsewhere in the URL
        // (e.g. in a path) does not switch on virtual-host style. Dotted bucket names
        // stay path style, likely because they would not match AWS's wildcard TLS
        // certificate when placed in the hostname.
        let virtual_host =
            is_aws_host(endpoint_host(&self.endpoint)) && !self.bucket.contains('.');
        let url_style = if virtual_host {
            rusty_s3::UrlStyle::VirtualHost
        } else {
            rusty_s3::UrlStyle::Path
        };
        rusty_s3::Bucket::new(
            endpoint,
            url_style,
            self.bucket.clone(),
            self.region.clone(),
        )
        .map_err(|e| OneIoError::NotSupported(format!("Invalid S3 bucket config: {e:?}")))
    }
}

/// Read an env var trimmed; unset, non-Unicode, empty, or whitespace-only values yield `None`.
fn env_non_empty(name: &str) -> Option<String> {
    std::env::var(name)
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
}

/// Read an env var as `u64`, yielding `None` when it is absent or does not parse.
fn env_u64(name: &str) -> Option<u64> {
    env_non_empty(name).and_then(|value| value.parse().ok())
}

/// Best-effort host extraction from an endpoint URL string: drops the scheme, userinfo,
/// path/query/fragment and port (IPv6 literals are not handled; they are never AWS hosts).
fn endpoint_host(endpoint: &str) -> &str {
    let rest = endpoint.split_once("://").map_or(endpoint, |(_, rest)| rest);
    let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest);
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
    host_port.split(':').next().unwrap_or(host_port)
}

/// Whether `host` is `amazonaws.com` / `amazonaws.com.cn` or a subdomain of either.
fn is_aws_host(host: &str) -> bool {
    let host = host.to_ascii_lowercase();
    ["amazonaws.com", "amazonaws.com.cn"].iter().any(|&domain| {
        host == domain
            || host
                .strip_suffix(domain)
                .is_some_and(|prefix| prefix.ends_with('.'))
    })
}
/// Normalize a user-supplied S3 endpoint into a URL string.
///
/// The steps are:
/// - trim surrounding whitespace;
/// - prepend `https://` unless an `http://` or `https://` scheme is already present
///   (the check ignores ASCII case, since URL schemes are case-insensitive; the
///   original spelling is kept);
/// - remove trailing `/` characters after the scheme.
///
/// ```text
/// "example.com"            -> "https://example.com"
/// "HTTP://minio:9000/"     -> "HTTP://minio:9000"
/// "https://example.com/p/" -> "https://example.com/p"
/// ```
pub(crate) fn normalize_endpoint(url: &str) -> String {
    let url = url.trim();
    let (scheme, rest) = split_http_scheme(url).unwrap_or(("https://", url));
    // Trim only the part after the scheme so the `//` separator itself is never
    // eaten (a bare "https://" must not become "https:").
    let rest = rest.trim_end_matches('/');
    format!("{scheme}{rest}")
}

/// Split `url` into its `http://`/`https://` prefix (matched ASCII-case-insensitively)
/// and the remainder, or return `None` when neither scheme is present.
fn split_http_scheme(url: &str) -> Option<(&str, &str)> {
    for prefix in ["https://", "http://"] {
        if url
            .get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
        {
            return Some(url.split_at(prefix.len()));
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `normalize_endpoint` adds a default scheme, keeps an explicit one, trims
    /// surrounding whitespace, and strips trailing slashes.
    #[test]
    fn test_normalize_endpoint() {
        let cases = [
            ("example.com", "https://example.com"),
            ("http://example.com", "http://example.com"),
            ("https://example.com/", "https://example.com"),
            ("https://example.com/path/", "https://example.com/path"),
            // Surrounding whitespace is removed before the scheme check.
            ("  example.com  ", "https://example.com"),
            // Every trailing slash is removed; the port is kept.
            ("http://localhost:9000///", "http://localhost:9000"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_endpoint(input), expected, "input: {input:?}");
        }
    }
}