use crate::error::{Error, Result};
use crate::retry;
use crate::types::{S3Bucket, S3Key};
use std::future::Future;
/// Splits a `bucket/key` style path into its bucket and key parts.
///
/// Any leading `/` characters are discarded first. The bucket is everything
/// up to the first remaining `/`; the key is everything after it (and may
/// itself contain further `/` characters or a trailing `/`). When no `/`
/// remains, the whole input is treated as the bucket and the key is empty.
///
/// Note that a URL scheme is not recognised: only leading slashes are
/// stripped, so an input such as `s3://bucket/key` is split at the first `/`
/// like any other string.
#[must_use]
pub fn split_s3_path(path: &str) -> (S3Bucket, S3Key) {
    let trimmed = path.trim_start_matches('/');
    // No separator left means "bucket only": fall back to an empty key.
    let (bucket, key) = trimmed.split_once('/').unwrap_or((trimmed, ""));
    (
        S3Bucket::from(bucket.to_owned()),
        S3Key::from(key.to_owned()),
    )
}
/// Joins a key prefix and a file name into a single object key.
///
/// Trailing `/` characters on `key_prefix` are removed so that exactly one
/// `/` separates the prefix from `file_name`. If the prefix is empty (or
/// consists only of slashes) the file name is returned on its own, without a
/// leading separator. `file_name` is used verbatim.
#[must_use]
pub fn build_s3_key(key_prefix: &str, file_name: &str) -> String {
    match key_prefix.trim_end_matches('/') {
        "" => file_name.to_owned(),
        prefix => format!("{prefix}/{file_name}"),
    }
}
/// Asynchronous object-storage operations addressed by `(bucket, key)`.
///
/// Implementors must be `Send + Sync`, and every method returns a future
/// that is itself `Send`. This file provides two implementations:
/// `AwsTransferStorage` (backed by an `aws_sdk_s3::Client`) and
/// `MemoryTransferStorage` (backed by an in-process map).
// NOTE(review): the `manual_async_fn` allows below carry no rationale;
// presumably the methods are spelled `-> impl Future<..> + Send` instead of
// `async fn` so the `Send` bound is part of the signature — confirm.
#[allow(clippy::module_name_repetitions)]
pub trait TransferStorage: Send + Sync {
/// Retrieves the object stored at `bucket`/`key`.
///
/// The future resolves to the object's bytes, or to an error if the object
/// could not be fetched.
#[allow(clippy::manual_async_fn)]
fn get_object(&self, bucket: &str, key: &str) -> impl Future<Output = Result<Vec<u8>>> + Send;
/// Stores `body` at `bucket`/`key`.
///
/// Takes ownership of `body`; the future resolves to `Ok(())` on success.
#[allow(clippy::manual_async_fn)]
fn put_object(
&self,
bucket: &str,
key: &str,
body: Vec<u8>,
) -> impl Future<Output = Result<()>> + Send;
/// Deletes the object at `bucket`/`key`.
///
/// Behaviour for a key that does not exist is implementation-defined; in
/// this file `MemoryTransferStorage` reports a not-found error in that case.
#[allow(clippy::manual_async_fn)]
fn delete_object(&self, bucket: &str, key: &str) -> impl Future<Output = Result<()>> + Send;
/// Copies the object at `source_bucket`/`source_key` to
/// `dest_bucket`/`dest_key`.
#[allow(clippy::manual_async_fn)]
fn copy_object(
&self,
source_bucket: &str,
source_key: &str,
dest_bucket: &str,
dest_key: &str,
) -> impl Future<Output = Result<()>> + Send;
}
/// `TransferStorage` implementation that issues requests through an
/// `aws_sdk_s3::Client`.
///
/// Cloning this value clones the wrapped client handle.
#[derive(Clone)]
pub struct AwsTransferStorage {
// SDK client used for every request made by this storage.
client: aws_sdk_s3::Client,
}
impl AwsTransferStorage {
/// Wraps an already-configured S3 client.
///
/// No requests are made here; the client is only stored for later use.
#[must_use]
pub const fn new(client: aws_sdk_s3::Client) -> Self {
Self { client }
}
}
impl TransferStorage for AwsTransferStorage {
    /// Downloads the object at `bucket`/`key` and returns its full contents.
    ///
    /// The request is driven by `retry::with_retry_and_timeout`; the closure
    /// builds a fresh request (with its own owned copies of the arguments)
    /// each time it is invoked.
    ///
    /// # Errors
    ///
    /// * A failed `GetObject` call is converted with
    ///   `Error::from_s3_sdk_status`, passing the HTTP status of the raw
    ///   response when one is available.
    /// * A failure while collecting the response body is reported as
    ///   `Error::api`.
    ///
    /// Both carry `bucket` and `key` as context and keep the SDK error as
    /// their source.
    async fn get_object(&self, bucket: &str, key: &str) -> Result<Vec<u8>> {
        retry::with_retry_and_timeout(|| {
            let client = self.client.clone();
            let bucket = bucket.to_string();
            let key = key.to_string();
            async move {
                let resp = client
                    .get_object()
                    .bucket(&bucket)
                    .key(&key)
                    .send()
                    .await
                    .map_err(|e| {
                        let status = e.raw_response().map(|r| r.status().as_u16());
                        let msg = e.to_string();
                        // Borrow here: `bucket`/`key` are still needed for the
                        // body-collection error below.
                        Error::from_s3_sdk_status(msg, status)
                            .with("bucket", &bucket)
                            .with("key", &key)
                            .with_source(e)
                    })?;
                let body = resp.body;
                let body_bytes = body.collect().await.map_err(|e| {
                    let msg = e.to_string();
                    Error::api(msg)
                        .with("bucket", bucket)
                        .with("key", key)
                        .with_source(e)
                })?;
                Ok(body_bytes.to_vec())
            }
        })
        .await
    }

    /// Uploads `body` to `bucket`/`key`.
    ///
    /// # Errors
    ///
    /// A failed `PutObject` call is converted with
    /// `Error::from_s3_sdk_status`, with `bucket` and `key` attached as
    /// context and the SDK error as source.
    async fn put_object(&self, bucket: &str, key: &str, body: Vec<u8>) -> Result<()> {
        retry::with_retry_and_timeout(|| {
            let client = self.client.clone();
            let bucket = bucket.to_string();
            let key = key.to_string();
            // `ByteStream::from` consumes its argument, so every invocation of
            // this closure needs its own copy of the payload.
            let body = aws_sdk_s3::primitives::ByteStream::from(body.clone());
            async move {
                client
                    .put_object()
                    .bucket(&bucket)
                    .key(&key)
                    .body(body)
                    .send()
                    .await
                    .map_err(|e| {
                        let status = e.raw_response().map(|r| r.status().as_u16());
                        let msg = e.to_string();
                        Error::from_s3_sdk_status(msg, status)
                            .with("bucket", bucket)
                            .with("key", key)
                            .with_source(e)
                    })?;
                Ok(())
            }
        })
        .await
    }

    /// Deletes the object at `bucket`/`key`.
    ///
    /// # Errors
    ///
    /// A failed `DeleteObject` call is converted with
    /// `Error::from_s3_sdk_status`, with `bucket` and `key` attached as
    /// context and the SDK error as source.
    async fn delete_object(&self, bucket: &str, key: &str) -> Result<()> {
        retry::with_retry_and_timeout(|| {
            let client = self.client.clone();
            let bucket = bucket.to_string();
            let key = key.to_string();
            async move {
                client
                    .delete_object()
                    .bucket(&bucket)
                    .key(&key)
                    .send()
                    .await
                    .map_err(|e| {
                        let status = e.raw_response().map(|r| r.status().as_u16());
                        let msg = e.to_string();
                        Error::from_s3_sdk_status(msg, status)
                            .with("bucket", &bucket)
                            .with("key", &key)
                            .with_source(e)
                    })?;
                Ok(())
            }
        })
        .await
    }

    /// Copies `source_bucket`/`source_key` to `dest_bucket`/`dest_key` with a
    /// server-side `CopyObject` request.
    ///
    /// The copy source is sent as `<bucket>/<key>`, percent-encoded as the S3
    /// API requires, so keys containing spaces, `+`, `%` or non-ASCII
    /// characters address the intended object.
    ///
    /// # Errors
    ///
    /// A failed `CopyObject` call is converted with
    /// `Error::from_s3_sdk_status`. Both the source and the destination
    /// location are attached as context, since either side can be the cause
    /// (for example a missing source object).
    async fn copy_object(
        &self,
        source_bucket: &str,
        source_key: &str,
        dest_bucket: &str,
        dest_key: &str,
    ) -> Result<()> {
        // Encoding does not depend on the attempt, so do it once up front.
        let copy_source = encode_copy_source(source_bucket, source_key);
        retry::with_retry_and_timeout(|| {
            let client = self.client.clone();
            let copy_source = copy_source.clone();
            let source_bucket = source_bucket.to_string();
            let source_key = source_key.to_string();
            let dest_bucket = dest_bucket.to_string();
            let dest_key = dest_key.to_string();
            async move {
                client
                    .copy_object()
                    .bucket(&dest_bucket)
                    .key(&dest_key)
                    .copy_source(copy_source)
                    .send()
                    .await
                    .map_err(|e| {
                        let status = e.raw_response().map(|r| r.status().as_u16());
                        let msg = e.to_string();
                        Error::from_s3_sdk_status(msg, status)
                            .with("source_bucket", &source_bucket)
                            .with("source_key", &source_key)
                            .with("dest_bucket", &dest_bucket)
                            .with("dest_key", &dest_key)
                            .with_source(e)
                    })?;
                Ok(())
            }
        })
        .await
    }
}

/// Builds the percent-encoded `<bucket>/<key>` value for a `CopyObject`
/// copy source.
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) and the `/`
/// separator are kept as-is; every other byte of the UTF-8 representation is
/// written as `%XX` with upper-case hex digits.
fn encode_copy_source(bucket: &str, key: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(bucket.len() + key.len() + 1);
    let bytes = bucket
        .bytes()
        .chain(std::iter::once(b'/'))
        .chain(key.bytes());
    for byte in bytes {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' | b'/' => {
                encoded.push(char::from(byte));
            }
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    encoded
}
// Shared, lock-protected map from `(bucket, key)` to the stored object bytes.
// The `Arc` lets several handles refer to the same map.
type MemoryStore =
std::sync::Arc<std::sync::RwLock<std::collections::HashMap<(String, String), Vec<u8>>>>;
/// `TransferStorage` implementation that keeps objects in process memory.
///
/// The store lives behind an `Arc`, so clones of this value observe and
/// modify the same set of objects.
#[derive(Clone, Debug, Default)]
pub struct MemoryTransferStorage {
// Objects keyed by `(bucket, key)`.
store: MemoryStore,
}
impl MemoryTransferStorage {
    /// Creates a storage with no objects in it.
    #[allow(clippy::new_without_default)]
    #[must_use]
    pub fn new() -> Self {
        Self {
            store: std::sync::Arc::default(),
        }
    }

    /// Places `json_bytes` at `bucket`/`key` without going through the async
    /// `TransferStorage` API, replacing any object already stored there.
    ///
    /// # Panics
    ///
    /// Panics if the store's lock is poisoned.
    pub fn inject_listing_result(&self, bucket: &str, key: &str, json_bytes: Vec<u8>) {
        let location = (bucket.to_owned(), key.to_owned());
        self.store.write().unwrap().insert(location, json_bytes);
    }

    /// Returns every `(bucket, key)` pair currently stored, in unspecified
    /// order. Only compiled for tests.
    ///
    /// # Panics
    ///
    /// Panics if the store's lock is poisoned.
    #[cfg(test)]
    #[must_use]
    pub fn test_keys(&self) -> Vec<(String, String)> {
        let objects = self.store.read().unwrap();
        objects.keys().cloned().collect()
    }
}
impl TransferStorage for MemoryTransferStorage {
async fn get_object(&self, bucket: &str, key: &str) -> Result<Vec<u8>> {
let guard = self.store.read().unwrap();
let bucket_ref = bucket.to_string();
let key_ref = key.to_string();
guard
.get(&(bucket.to_string(), key.to_string()))
.cloned()
.ok_or_else(|| {
Error::not_found("object not found")
.with("bucket", bucket_ref)
.with("key", key_ref)
})
}
async fn put_object(&self, bucket: &str, key: &str, body: Vec<u8>) -> Result<()> {
let mut guard = self.store.write().unwrap();
guard.insert((bucket.to_string(), key.to_string()), body);
Ok(())
}
async fn delete_object(&self, bucket: &str, key: &str) -> Result<()> {
let mut guard = self.store.write().unwrap();
if guard
.remove(&(bucket.to_string(), key.to_string()))
.is_some()
{
Ok(())
} else {
Err(Error::not_found("object not found")
.with("bucket", bucket)
.with("key", key))
}
}
async fn copy_object(
&self,
source_bucket: &str,
source_key: &str,
dest_bucket: &str,
dest_key: &str,
) -> Result<()> {
let bytes = self.get_object(source_bucket, source_key).await?;
self.put_object(dest_bucket, dest_key, bytes).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- split_s3_path -----------------------------------------------------

    #[test]
    fn split_s3_path_leading_slashes() {
        let (b, p) = split_s3_path("/bucket/prefix/");
        assert_eq!(b.as_str(), "bucket");
        assert_eq!(p.as_str(), "prefix/");
    }

    #[test]
    fn split_s3_path_repeated_leading_slashes() {
        // Every leading slash is stripped, not just the first one.
        let (b, p) = split_s3_path("///bucket/key");
        assert_eq!(b.as_str(), "bucket");
        assert_eq!(p.as_str(), "key");
    }

    #[test]
    fn split_s3_path_bucket_only() {
        let (b, p) = split_s3_path("my-bucket");
        assert_eq!(b.as_str(), "my-bucket");
        assert_eq!(p.as_str(), "");
    }

    #[test]
    fn split_s3_path_bucket_with_trailing_slash() {
        // A separator with nothing after it yields an empty key.
        let (b, p) = split_s3_path("my-bucket/");
        assert_eq!(b.as_str(), "my-bucket");
        assert_eq!(p.as_str(), "");
    }

    #[test]
    fn split_s3_path_bucket_and_prefix() {
        let (b, p) = split_s3_path("my-bucket/staging/listings/");
        assert_eq!(b.as_str(), "my-bucket");
        assert_eq!(p.as_str(), "staging/listings/");
    }

    #[test]
    fn split_s3_path_empty_after_trim() {
        let (b, p) = split_s3_path("/");
        assert_eq!(b.as_str(), "");
        assert_eq!(p.as_str(), "");
    }

    // ---- MemoryTransferStorage ---------------------------------------------

    #[async_test_macros::async_test]
    async fn memory_storage_put_then_get() {
        let storage = MemoryTransferStorage::new();
        storage
            .put_object("b", "k", b"hello".to_vec())
            .await
            .unwrap();
        let bytes = storage.get_object("b", "k").await.unwrap();
        assert_eq!(bytes, b"hello");
    }

    #[async_test_macros::async_test]
    async fn memory_storage_get_missing_returns_error() {
        let storage = MemoryTransferStorage::new();
        let res = storage.get_object("b", "missing").await;
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().kind, crate::error::ErrorKind::NotFound);
    }

    #[async_test_macros::async_test]
    async fn memory_storage_inject_listing_result_then_get() {
        let storage = MemoryTransferStorage::new();
        let json_bytes = br#"{"files":[],"paths":[],"truncated":false}"#.to_vec();
        storage.inject_listing_result("bucket", "listings/out.json", json_bytes.clone());
        let bytes = storage
            .get_object("bucket", "listings/out.json")
            .await
            .unwrap();
        assert_eq!(bytes, json_bytes);
    }

    #[async_test_macros::async_test]
    async fn memory_storage_put_overwrites() {
        let storage = MemoryTransferStorage::new();
        storage
            .put_object("b", "k", b"first".to_vec())
            .await
            .unwrap();
        storage
            .put_object("b", "k", b"second".to_vec())
            .await
            .unwrap();
        let bytes = storage.get_object("b", "k").await.unwrap();
        assert_eq!(bytes, b"second");
    }

    #[async_test_macros::async_test]
    async fn memory_storage_put_then_delete_then_get_not_found() {
        let storage = MemoryTransferStorage::new();
        storage
            .put_object("b", "k", b"hello".to_vec())
            .await
            .unwrap();
        storage.delete_object("b", "k").await.unwrap();
        let res = storage.get_object("b", "k").await;
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().kind, crate::error::ErrorKind::NotFound);
    }

    #[async_test_macros::async_test]
    async fn memory_storage_delete_missing_returns_not_found() {
        let storage = MemoryTransferStorage::new();
        let res = storage.delete_object("b", "missing").await;
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().kind, crate::error::ErrorKind::NotFound);
    }

    #[async_test_macros::async_test]
    async fn memory_storage_copy_duplicates_object() {
        let storage = MemoryTransferStorage::new();
        storage
            .put_object("src", "a", b"payload".to_vec())
            .await
            .unwrap();
        storage.copy_object("src", "a", "dst", "b").await.unwrap();
        // The destination receives the bytes and the source is left intact.
        let copied = storage.get_object("dst", "b").await.unwrap();
        assert_eq!(copied, b"payload");
        let original = storage.get_object("src", "a").await.unwrap();
        assert_eq!(original, b"payload");
    }

    #[async_test_macros::async_test]
    async fn memory_storage_copy_missing_source_returns_not_found() {
        let storage = MemoryTransferStorage::new();
        let res = storage.copy_object("src", "missing", "dst", "b").await;
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().kind, crate::error::ErrorKind::NotFound);
        // Nothing must have been written to the destination.
        assert!(storage.test_keys().is_empty());
    }
}