use anyhow::{anyhow, Result};
use aws_sdk_s3::{primitives::ByteStream, Client};
use std::path::Path;
/// Splits an S3 location string into a `(bucket, prefix)` pair.
///
/// A leading `s3://` scheme is removed when present. The text before the
/// first `/` becomes the bucket name and everything after that `/` becomes
/// the prefix; when there is no `/` the prefix is the empty string.
pub fn parse_bucket(bucket: &str) -> (String, String) {
    let location = bucket.strip_prefix("s3://").unwrap_or(bucket);
    match location.split_once('/') {
        Some((name, prefix)) => (name.to_owned(), prefix.to_owned()),
        None => (location.to_owned(), String::new()),
    }
}
/// Builds an S3 [`Client`] from the configuration produced by
/// `aws_config::from_env()`.
///
/// When `endpoint` is `Some`, it is applied as the endpoint URL on the config
/// loader before the configuration is loaded; otherwise the loader is used
/// unchanged.
///
/// The body contains no fallible step, so this currently always returns `Ok`.
pub async fn create_client(endpoint: Option<&str>) -> Result<Client> {
    let loader = match endpoint {
        Some(url) => aws_config::from_env().endpoint_url(url),
        None => aws_config::from_env(),
    };
    let sdk_config = loader.load().await;
    Ok(Client::new(&sdk_config))
}
/// Uploads an in-memory buffer to `s3://{bucket}/{key}` with a `PutObject`
/// request.
///
/// # Errors
///
/// Propagates any error returned by sending the request.
pub async fn upload_bytes(
    client: &Client,
    bucket: &str,
    key: &str,
    data: Vec<u8>,
) -> Result<()> {
    // Capture the size up front: `data` is moved into the request body.
    let byte_count = data.len();
    let request = client
        .put_object()
        .bucket(bucket)
        .key(key)
        .body(ByteStream::from(data));
    request.send().await?;
    tracing::debug!("Uploaded {} bytes to s3://{}/{}", byte_count, bucket, key);
    Ok(())
}
/// Uploads the file at `path` to `s3://{bucket}/{key}` with a `PutObject`
/// request.
///
/// The request body is created with `ByteStream::from_path`.
///
/// # Errors
///
/// Propagates an error if the body cannot be created from `path` or if
/// sending the request fails.
pub async fn upload_file(
    client: &Client,
    bucket: &str,
    key: &str,
    path: &Path,
) -> Result<()> {
    let file_stream = ByteStream::from_path(path).await?;
    let request = client
        .put_object()
        .bucket(bucket)
        .key(key)
        .body(file_stream);
    request.send().await?;
    tracing::debug!("Uploaded {} to s3://{}/{}", path.display(), bucket, key);
    Ok(())
}
/// Downloads the object at `s3://{bucket}/{key}` and returns its full
/// contents as a byte vector.
///
/// The response body is collected completely before being returned, so the
/// whole object is held in memory.
///
/// # Errors
///
/// Propagates an error if the `GetObject` request fails or if collecting the
/// response body fails.
pub async fn download_bytes(
    client: &Client,
    bucket: &str,
    key: &str,
) -> Result<Vec<u8>> {
    let response = client
        .get_object()
        .bucket(bucket)
        .key(key)
        .send()
        .await?;
    let aggregated = response.body.collect().await?;
    let bytes = aggregated.into_bytes().to_vec();
    tracing::debug!("Downloaded {} bytes from s3://{}/{}", bytes.len(), bucket, key);
    Ok(bytes)
}
/// Downloads the object at `s3://{bucket}/{key}` and writes it to `path`.
///
/// The object is fetched in full via [`download_bytes`] first and only then
/// written with `tokio::fs::write`, so nothing is written if the download
/// fails.
///
/// # Errors
///
/// Propagates an error from the download or from writing the file.
pub async fn download_file(
    client: &Client,
    bucket: &str,
    key: &str,
    path: &Path,
) -> Result<()> {
    let contents = download_bytes(client, bucket, key).await?;
    tokio::fs::write(path, &contents).await?;
    Ok(())
}
pub async fn list_objects(
client: &Client,
bucket: &str,
prefix: &str,
) -> Result<Vec<String>> {
let mut keys = Vec::new();
let mut continuation_token: Option<String> = None;
loop {
let mut req = client
.list_objects_v2()
.bucket(bucket)
.prefix(prefix);
if let Some(token) = &continuation_token {
req = req.continuation_token(token);
}
let resp = req.send().await?;
if let Some(contents) = resp.contents {
for obj in contents {
if let Some(key) = obj.key {
keys.push(key);
}
}
}
if resp.is_truncated.unwrap_or(false) {
continuation_token = resp.next_continuation_token;
} else {
break;
}
}
Ok(keys)
}
/// Checks whether an object exists at `s3://{bucket}/{key}` using a
/// `HeadObject` request.
///
/// Returns `Ok(true)` when the request succeeds and `Ok(false)` when the
/// service reports `NotFound` or the HTTP response status is 404.
///
/// Not-found detection relies on the typed error variant and the response
/// status code rather than on searching the error's text, so that unrelated
/// failures whose text happens to contain `404` (for example inside a request
/// identifier) are not mistaken for a missing object.
///
/// # Errors
///
/// Any other failure is returned as an error.
// NOTE(review): presumably `dead_code` is allowed because not every build
// target calls this helper — confirm and document the reason.
#[allow(dead_code)]
pub async fn exists(client: &Client, bucket: &str, key: &str) -> Result<bool> {
    use aws_sdk_s3::error::SdkError;
    use aws_sdk_s3::operation::head_object::HeadObjectError;

    const HTTP_NOT_FOUND: u16 = 404;

    match client.head_object().bucket(bucket).key(key).send().await {
        Ok(_) => Ok(true),
        Err(SdkError::ServiceError(err)) => {
            let missing = matches!(err.err(), HeadObjectError::NotFound(_))
                || err.raw().status().as_u16() == HTTP_NOT_FOUND;
            if missing {
                Ok(false)
            } else {
                Err(anyhow!("Failed to check object existence: {:?}", err.err()))
            }
        }
        Err(e) => {
            // Non-service errors only count as "not found" when a response
            // was actually received and its status is 404.
            let missing = e
                .raw_response()
                .map_or(false, |resp| resp.status().as_u16() == HTTP_NOT_FOUND);
            if missing {
                Ok(false)
            } else {
                Err(anyhow!("Failed to check object existence: {}", e))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `parse_bucket(input)` yields the expected bucket and prefix.
    fn assert_parsed(input: &str, expected_bucket: &str, expected_prefix: &str) {
        let (bucket, prefix) = parse_bucket(input);
        assert_eq!(bucket, expected_bucket);
        assert_eq!(prefix, expected_prefix);
    }

    #[test]
    fn test_parse_bucket_simple() {
        assert_parsed("my-bucket", "my-bucket", "");
    }

    #[test]
    fn test_parse_bucket_with_prefix() {
        assert_parsed("my-bucket/some/prefix", "my-bucket", "some/prefix");
    }

    #[test]
    fn test_parse_bucket_s3_url() {
        assert_parsed("s3://my-bucket/backups", "my-bucket", "backups");
    }

    #[test]
    fn test_parse_bucket_s3_url_no_prefix() {
        assert_parsed("s3://my-bucket", "my-bucket", "");
    }

    // The integration tests below are `#[ignore]`d. They read the target
    // bucket from `WALSYNC_TEST_BUCKET` and an optional endpoint override from
    // `AWS_ENDPOINT_URL_S3`.

    /// Reads the integration-test bucket name from the environment, if set.
    fn get_test_bucket() -> Option<String> {
        std::env::var("WALSYNC_TEST_BUCKET").ok()
    }

    /// Reads the optional S3 endpoint override from the environment, if set.
    fn get_test_endpoint() -> Option<String> {
        std::env::var("AWS_ENDPOINT_URL_S3").ok()
    }

    /// Resolves the test bucket and builds a client for the integration tests.
    ///
    /// Panics when `WALSYNC_TEST_BUCKET` is not set.
    async fn integration_context() -> (Client, String) {
        let bucket = get_test_bucket().expect("WALSYNC_TEST_BUCKET not set");
        let endpoint = get_test_endpoint();
        let client = create_client(endpoint.as_deref()).await.unwrap();
        (client, bucket)
    }

    /// Deletes a single object created by a test, panicking on failure.
    async fn delete_test_object(client: &Client, bucket: &str, key: &str) {
        client
            .delete_object()
            .bucket(bucket)
            .key(key)
            .send()
            .await
            .unwrap();
    }

    #[tokio::test]
    #[ignore]
    async fn test_integration_upload_download_bytes() {
        let (client, bucket) = integration_context().await;
        let test_key = format!("walsync-test/{}.txt", uuid::Uuid::new_v4());
        let test_data = b"Hello from walsync integration test!".to_vec();

        upload_bytes(&client, &bucket, &test_key, test_data.clone())
            .await
            .unwrap();
        let downloaded = download_bytes(&client, &bucket, &test_key).await.unwrap();
        assert_eq!(downloaded, test_data);

        delete_test_object(&client, &bucket, &test_key).await;
    }

    #[tokio::test]
    #[ignore]
    async fn test_integration_list_objects() {
        let (client, bucket) = integration_context().await;
        let prefix = format!("walsync-test-list/{}/", uuid::Uuid::new_v4());

        for i in 0..3 {
            let key = format!("{}file{}.txt", prefix, i);
            let body = format!("content {}", i).into_bytes();
            upload_bytes(&client, &bucket, &key, body).await.unwrap();
        }

        let keys = list_objects(&client, &bucket, &prefix).await.unwrap();
        assert_eq!(keys.len(), 3);

        for key in &keys {
            delete_test_object(&client, &bucket, key).await;
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_integration_exists() {
        let (client, bucket) = integration_context().await;
        let test_key = format!("walsync-test/{}.txt", uuid::Uuid::new_v4());

        let exists_before = exists(&client, &bucket, &test_key).await.unwrap();
        assert!(!exists_before);

        upload_bytes(&client, &bucket, &test_key, b"test".to_vec())
            .await
            .unwrap();
        let exists_after = exists(&client, &bucket, &test_key).await.unwrap();
        assert!(exists_after);

        delete_test_object(&client, &bucket, &test_key).await;
    }
}