use crate::StorageBackend;
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use aws_sdk_s3::Client;
use bytes::Bytes;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tracing::{debug, warn};
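/// Configuration for [`S3Backend`].
///
/// When `endpoint` is set, the backend targets an S3-compatible server
/// (e.g. MinIO) using path-style addressing; otherwise the default AWS
/// configuration chain is used. `part_size` is both the multipart-upload
/// threshold and the size of each uploaded part.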
#[derive(Clone, Debug)]
pub struct S3Config {
pub bucket: String,
pub endpoint: Option<String>,
pub region: Option<String>,
pub access_key_id: Option<String>,
pub secret_access_key: Option<String>,
pub part_size: u64,
pub max_concurrent_parts: usize,
pub max_retries: u32,
pub initial_retry_delay_ms: u64,
}
impl Default for S3Config {
fn default() -> Self {
S3Config {
bucket: String::new(),
endpoint: None,
region: None,
access_key_id: None,
secret_access_key: None,
            part_size: 100 * 1024 * 1024,
            max_concurrent_parts: 8,
max_retries: 3,
initial_retry_delay_ms: 100,
}
}
}
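/// An S3 (and S3-compatible) storage backend with automatic retries and
/// concurrent multipart uploads for large objects.
///
/// Cloning is cheap: the client, configuration, and statistics are shared
/// through `Arc`s.
///
/// Illustrative usage; the bucket name and MinIO endpoint below are
/// placeholders, and the snippet is not compiled because the import path
/// depends on the crate layout:
///
/// ```ignore
/// let config = S3Config {
///     bucket: "media".to_string(),
///     endpoint: Some("http://localhost:9000".to_string()),
///     ..Default::default()
/// };
/// let backend = S3Backend::with_config(config).await?;
/// backend.put("objects/abc123", b"payload").await?;
/// assert!(backend.exists("objects/abc123").await?);
/// ```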
#[derive(Clone)]
pub struct S3Backend {
client: Client,
config: Arc<S3Config>,
stats: Arc<S3Stats>,
}
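/// Cumulative transfer counters, updated with relaxed atomics.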
#[derive(Debug)]
struct S3Stats {
total_bytes_uploaded: AtomicU64,
total_bytes_downloaded: AtomicU64,
total_objects_deleted: AtomicU64,
}
impl S3Stats {
fn new() -> Self {
S3Stats {
total_bytes_uploaded: AtomicU64::new(0),
total_bytes_downloaded: AtomicU64::new(0),
total_objects_deleted: AtomicU64::new(0),
}
}
}
impl S3Backend {
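    /// Creates a backend for `bucket` with default settings, using the
    /// standard AWS credential and region resolution chain. The bucket is
    /// created if it does not already exist.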
pub async fn new(bucket: impl Into<String>) -> Result<Self> {
let config = S3Config {
bucket: bucket.into(),
..Default::default()
};
Self::with_config(config).await
}
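    /// Creates a backend from an explicit [`S3Config`]. With a custom
    /// endpoint the client uses path-style addressing and falls back to the
    /// `us-east-1` region when none is configured; the bucket is created if
    /// it does not already exist.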
pub async fn with_config(config: S3Config) -> Result<Self> {
let client = if let Some(endpoint) = &config.endpoint {
debug!("Using custom S3 endpoint: {}", endpoint);
let mut builder = aws_sdk_s3::config::Builder::new()
.behavior_version(aws_sdk_s3::config::BehaviorVersion::latest())
.endpoint_url(endpoint.clone())
.force_path_style(true);
if let Some(region) = &config.region {
builder = builder.region(aws_sdk_s3::config::Region::new(region.clone()));
} else {
builder = builder.region(aws_sdk_s3::config::Region::new("us-east-1"));
}
if let (Some(key_id), Some(secret)) = (&config.access_key_id, &config.secret_access_key)
{
let credentials =
aws_sdk_s3::config::Credentials::new(key_id, secret, None, None, "S3Backend");
builder = builder.credentials_provider(credentials);
}
Client::from_conf(builder.build())
} else {
let sdk_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
.load()
.await;
Client::new(&sdk_config)
};
match client.create_bucket().bucket(&config.bucket).send().await {
Ok(_) => {
debug!("S3 bucket '{}' created successfully", config.bucket);
}
Err(e) => {
let already_exists = e
.as_service_error()
.map(|se| se.is_bucket_already_owned_by_you() || se.is_bucket_already_exists())
.unwrap_or(false);
if already_exists {
debug!("S3 bucket '{}' already exists, continuing", config.bucket);
} else {
return Err(e).context(format!(
"Failed to access or create S3 bucket: {}",
config.bucket
));
}
}
}
debug!(
"Successfully connected to S3 bucket: {} with endpoint: {:?}",
config.bucket,
config.endpoint.as_deref().unwrap_or("AWS default")
);
Ok(S3Backend {
client,
config: Arc::new(config),
stats: Arc::new(S3Stats::new()),
})
}
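    /// Creates a backend with explicitly supplied credentials and region,
    /// for S3-compatible servers that sit outside the AWS configuration
    /// chain. `config.access_key_id` and `config.secret_access_key` are
    /// ignored in favor of the arguments.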
pub async fn with_credentials(
config: S3Config,
access_key: &str,
secret_key: &str,
region: &str,
) -> Result<Self> {
use aws_sdk_s3::config::{Credentials, Region};
let credentials = Credentials::new(
access_key,
secret_key,
None, None, "mediagit-explicit-credentials",
);
let mut s3_config_builder = aws_sdk_s3::config::Builder::new()
.credentials_provider(credentials)
.region(Region::new(region.to_string()))
.force_path_style(true);
if let Some(endpoint) = &config.endpoint {
debug!("Using custom S3 endpoint with credentials: {}", endpoint);
s3_config_builder = s3_config_builder.endpoint_url(endpoint.clone());
}
let client = Client::from_conf(s3_config_builder.build());
match client.create_bucket().bucket(&config.bucket).send().await {
Ok(_) => {
debug!(
"S3-compatible bucket '{}' created successfully",
config.bucket
);
}
Err(e) => {
                let already_exists = e
                    .as_service_error()
                    .map(|se| se.is_bucket_already_owned_by_you() || se.is_bucket_already_exists())
                    .unwrap_or(false);
                if already_exists {
debug!(
"S3-compatible bucket '{}' already exists, continuing",
config.bucket
);
} else {
return Err(e).context(format!(
"Failed to access or create S3-compatible bucket: {}. Check credentials and endpoint.",
config.bucket
));
}
}
}
debug!(
"Successfully connected to S3-compatible bucket: {} at {:?}",
config.bucket, config.endpoint
);
Ok(S3Backend {
client,
config: Arc::new(config),
stats: Arc::new(S3Stats::new()),
})
}
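    /// Returns `(bytes_uploaded, bytes_downloaded, objects_deleted)`
    /// accumulated since this backend was created.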
pub fn stats(&self) -> (u64, u64, u64) {
(
self.stats.total_bytes_uploaded.load(Ordering::Relaxed),
self.stats.total_bytes_downloaded.load(Ordering::Relaxed),
self.stats.total_objects_deleted.load(Ordering::Relaxed),
)
}
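    /// Rejects empty keys and keys with a leading `/` (which would create
    /// an object under an empty path segment).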
fn validate_key(key: &str) -> Result<()> {
if key.is_empty() {
return Err(anyhow!("key cannot be empty"));
}
if key.starts_with('/') {
return Err(anyhow!("key cannot start with '/'"));
}
Ok(())
}
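    /// Runs `operation` until it succeeds, sleeping between attempts with
    /// exponential backoff (doubling from `initial_retry_delay_ms`, capped
    /// at 10 seconds) for up to `max_retries` retries. Every error is
    /// treated as retryable.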
async fn with_retry<F, T>(&self, mut operation: F) -> Result<T>
where
F: FnMut() -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>,
{
let mut retry_count = 0;
let mut delay_ms = self.config.initial_retry_delay_ms;
loop {
match operation().await {
Ok(result) => return Ok(result),
Err(e) => {
retry_count += 1;
                    if retry_count > self.config.max_retries {
return Err(e)
.context(format!("Failed after {} retries", self.config.max_retries));
}
warn!(
"Operation failed (attempt {}/{}), retrying in {}ms: {}",
retry_count, self.config.max_retries, delay_ms, e
);
tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
                    // Exponential backoff, capped at 10 seconds.
                    delay_ms = (delay_ms * 2).min(10_000);
                }
}
}
}
}
impl fmt::Debug for S3Backend {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("S3Backend")
.field("bucket", &self.config.bucket)
.field("endpoint", &self.config.endpoint)
.field("part_size", &self.config.part_size)
.field("max_concurrent_parts", &self.config.max_concurrent_parts)
.finish()
}
}
#[async_trait]
impl StorageBackend for S3Backend {
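    /// Downloads the complete object at `key` into memory.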
async fn get(&self, key: &str) -> Result<Vec<u8>> {
Self::validate_key(key)?;
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let key_clone = key.to_string();
let stats = self.stats.clone();
self.with_retry(|| {
let client = client.clone();
let bucket = bucket.clone();
let key = key_clone.clone();
let stats = stats.clone();
Box::pin(async move {
debug!("Getting object from S3: {}", key);
let response = client
.get_object()
.bucket(&bucket)
.key(&key)
.send()
.await
.map_err(|e| anyhow!("Failed to get object: {}", e))?;
let body = response
.body
.collect()
.await
.map_err(|e| anyhow!("Failed to read object body: {}", e))?;
let data = body.into_bytes().to_vec();
stats
.total_bytes_downloaded
.fetch_add(data.len() as u64, Ordering::Relaxed);
Ok(data)
})
})
.await
}
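    /// Stores `data` at `key`, switching to a multipart upload when the
    /// payload exceeds `part_size`.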
async fn put(&self, key: &str, data: &[u8]) -> Result<()> {
Self::validate_key(key)?;
if data.len() as u64 <= self.config.part_size {
return self.put_simple(key, data).await;
}
self.put_multipart(key, data).await
}
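    /// Checks whether `key` exists via a `HeadObject` request.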
async fn exists(&self, key: &str) -> Result<bool> {
Self::validate_key(key)?;
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let key_clone = key.to_string();
self.with_retry(|| {
let client = client.clone();
let bucket = bucket.clone();
let key = key_clone.clone();
Box::pin(async move {
debug!("Checking if object exists in S3: {}", key);
match client.head_object().bucket(&bucket).key(&key).send().await {
Ok(_) => {
debug!("Object exists: {}", key);
Ok(true)
}
Err(e) => {
                        // HeadObject models a missing key as a 404 `NotFound`
                        // service error.
                        let not_found = e
                            .as_service_error()
                            .map(|se| se.is_not_found())
                            .unwrap_or(false);
                        if not_found {
                            debug!("Object does not exist: {}", key);
                            Ok(false)
                        } else {
                            Err(anyhow!("Failed to check object existence: {}", e))
                        }
}
}
})
})
.await
}
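    /// Deletes the object at `key`. S3 reports success even when the key
    /// does not exist, so this is idempotent.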
async fn delete(&self, key: &str) -> Result<()> {
Self::validate_key(key)?;
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let key_clone = key.to_string();
let stats = self.stats.clone();
self.with_retry(|| {
let client = client.clone();
let bucket = bucket.clone();
let key = key_clone.clone();
let stats = stats.clone();
Box::pin(async move {
debug!("Deleting object from S3: {}", key);
client
.delete_object()
.bucket(&bucket)
.key(&key)
.send()
.await
.map_err(|e| anyhow!("Failed to delete object: {}", e))?;
stats.total_objects_deleted.fetch_add(1, Ordering::Relaxed);
Ok(())
})
})
.await
}
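    /// Lists all keys under `prefix` (all keys when `prefix` is empty),
    /// following `ListObjectsV2` pagination, and returns them sorted.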
async fn list_objects(&self, prefix: &str) -> Result<Vec<String>> {
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let prefix_clone = prefix.to_string();
self.with_retry(|| {
let client = client.clone();
let bucket = bucket.clone();
let prefix = prefix_clone.clone();
Box::pin(async move {
debug!("Listing objects in S3 with prefix: '{}'", prefix);
let mut result = vec![];
let mut continuation_token: Option<String> = None;
loop {
let mut request = client.list_objects_v2().bucket(&bucket);
if !prefix.is_empty() {
request = request.prefix(&prefix);
}
if let Some(token) = continuation_token {
request = request.continuation_token(token);
}
let response = request
.send()
.await
.map_err(|e| anyhow!("Failed to list objects: {}", e))?;
for obj in response.contents() {
if let Some(key) = obj.key() {
result.push(key.to_string());
}
}
                    if response.is_truncated() == Some(true) {
                        match response.next_continuation_token() {
                            Some(token) => continuation_token = Some(token.to_string()),
                            // A truncated response without a token would loop forever.
                            None => break,
                        }
                    } else {
                        break;
                    }
}
result.sort();
debug!("Found {} objects with prefix: '{}'", result.len(), prefix);
Ok(result)
})
})
.await
}
}
impl S3Backend {
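    /// Uploads `data` in a single `PutObject` request.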
async fn put_simple(&self, key: &str, data: &[u8]) -> Result<()> {
debug!("Putting small object to S3: {} ({} bytes)", key, data.len());
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let key_clone = key.to_string();
let data_vec = data.to_vec();
let stats = self.stats.clone();
self.with_retry(|| {
let client = client.clone();
let bucket = bucket.clone();
let key = key_clone.clone();
let data = data_vec.clone();
let stats = stats.clone();
            Box::pin(async move {
                // Capture the length up front so the buffer can be moved
                // into the request body without an extra copy.
                let len = data.len() as u64;
                client
                    .put_object()
                    .bucket(&bucket)
                    .key(&key)
                    .body(Bytes::from(data).into())
                    .send()
                    .await
                    .map_err(|e| anyhow!("Failed to put object: {}", e))?;
                stats
                    .total_bytes_uploaded
                    .fetch_add(len, Ordering::Relaxed);
debug!("Successfully put object to S3: {}", key);
Ok(())
})
})
.await
}
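    /// Uploads `data` as an S3 multipart upload: `part_size` chunks are
    /// uploaded on spawned tasks with at most `max_concurrent_parts` in
    /// flight, then the collected ETags are assembled in part order with
    /// `CompleteMultipartUpload`.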
async fn put_multipart(&self, key: &str, data: &[u8]) -> Result<()> {
debug!(
"Putting large object to S3 (multipart): {} ({} bytes)",
key,
data.len()
);
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let key_clone = key.to_string();
let multipart = client
.create_multipart_upload()
.bucket(&bucket)
.key(&key_clone)
.send()
.await
.map_err(|e| anyhow!("Failed to initiate multipart upload: {}", e))?;
let upload_id = multipart
.upload_id()
.ok_or_else(|| anyhow!("No upload ID returned from S3"))?
.to_string();
debug!(
"Initiated multipart upload for {}: {}",
key_clone, upload_id
);
        let mut part_handles = vec![];
        let mut parts: Vec<(i32, String)> = vec![];
        let part_size = self.config.part_size as usize;
        let mut part_number = 1;
for chunk in data.chunks(part_size) {
let client = client.clone();
let bucket = bucket.clone();
let key = key_clone.clone();
let upload_id = upload_id.clone();
let stats = self.stats.clone();
let chunk_data = chunk.to_vec();
let part_num = part_number;
let handle = tokio::spawn(async move {
debug!(
"Uploading part {} ({} bytes) for key: {}",
part_num,
chunk_data.len(),
key
);
                let part_len = chunk_data.len() as u64;
                let response = client
                    .upload_part()
                    .bucket(&bucket)
                    .key(&key)
                    .upload_id(&upload_id)
                    .part_number(part_num)
                    .body(Bytes::from(chunk_data).into())
                    .send()
                    .await
                    .map_err(|e| anyhow!("Failed to upload part {}: {}", part_num, e))?;
                let etag = response
                    .e_tag()
                    .ok_or_else(|| anyhow!("No ETag returned for part {}", part_num))?
                    .to_string();
                stats
                    .total_bytes_uploaded
                    .fetch_add(part_len, Ordering::Relaxed);
Ok::<_, anyhow::Error>((part_num, etag))
});
part_handles.push(handle);
            // Bound concurrency: once the window is full, await the oldest
            // in-flight part and record its (part number, ETag) so every
            // part is present in the final CompleteMultipartUpload.
            if part_handles.len() >= self.config.max_concurrent_parts {
                let handle = part_handles.remove(0);
                parts.push(handle.await??);
            }
part_number += 1;
}
        // Await the remaining in-flight parts.
        for handle in part_handles {
            parts.push(handle.await??);
        }
parts.sort_by_key(|p| p.0);
let part_list: Vec<_> = parts
.into_iter()
.map(|(part_num, etag)| {
aws_sdk_s3::types::CompletedPart::builder()
.part_number(part_num)
.e_tag(etag)
.build()
})
.collect();
client
.complete_multipart_upload()
.bucket(&bucket)
.key(&key_clone)
.upload_id(&upload_id)
.multipart_upload(
aws_sdk_s3::types::CompletedMultipartUpload::builder()
.set_parts(Some(part_list))
.build(),
)
.send()
.await
.map_err(|e| anyhow!("Failed to complete multipart upload: {}", e))?;
debug!("Successfully completed multipart upload for {}", key_clone);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = S3Config::default();
assert_eq!(config.part_size, 100 * 1024 * 1024);
assert_eq!(config.max_concurrent_parts, 8);
assert_eq!(config.max_retries, 3);
assert_eq!(config.initial_retry_delay_ms, 100);
}
#[test]
fn test_validate_key() {
assert!(S3Backend::validate_key("valid_key").is_ok());
assert!(S3Backend::validate_key("path/to/key").is_ok());
assert!(S3Backend::validate_key("").is_err());
assert!(S3Backend::validate_key("/invalid").is_err());
}
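    // Sketch of overriding individual fields with struct-update syntax;
    // the values below are arbitrary examples.
    #[test]
    fn test_config_override() {
        let config = S3Config {
            bucket: "media".to_string(),
            part_size: 8 * 1024 * 1024,
            ..Default::default()
        };
        assert_eq!(config.bucket, "media");
        assert_eq!(config.part_size, 8 * 1024 * 1024);
        // Fields not named keep their defaults.
        assert_eq!(config.max_concurrent_parts, 8);
        assert_eq!(config.max_retries, 3);
    }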
    #[test]
    fn test_debug_impl() {
        let config = S3Config {
            bucket: "test-bucket".to_string(),
            ..Default::default()
        };
        let debug = format!("{:?}", config);
        assert!(debug.contains("test-bucket"));
    }
}