use crate::error::{Error, Result};
use crate::retry;
/// Splits a `bucket/prefix` style path into `(bucket, prefix)`.
///
/// All leading `/` characters are discarded first. The bucket is the text before the
/// first remaining `/`; the prefix is everything after that slash (it may be empty and
/// may itself start with `/` if the input contained `//`). Input without any `/` is
/// returned entirely as the bucket with an empty prefix.
///
/// No URI scheme is recognised: `"s3://b/k"` is not special-cased.
#[must_use]
pub fn split_s3_path(path: &str) -> (&str, &str) {
    let trimmed = path.trim_start_matches('/');
    trimmed.split_once('/').unwrap_or((trimmed, ""))
}
/// Async object-store interface: whole-object reads and writes addressed by
/// `(bucket, key)`.
///
/// Methods return hand-boxed, `Send` futures instead of using `async fn`, which keeps the
/// trait object-safe (usable as `dyn TransferStorage`). The `Send + Sync` supertraits let
/// an implementation be shared across threads.
///
/// The returned futures are tied only to the lifetime of `&self`, not to `bucket`/`key`,
/// so implementations copy those arguments before building the future.
///
/// Implementations in this file: [`AwsTransferStorage`] (Amazon S3 via the AWS SDK) and
/// [`MemoryTransferStorage`] (an in-process map used by the tests in this file).
// NOTE(review): the lint is presumably allowed because the trait name echoes its
// module's name; the module name is not visible from here — confirm.
#[allow(clippy::module_name_repetitions)]
pub trait TransferStorage: Send + Sync {
    /// Fetches the complete contents of the object `key` in `bucket`.
    ///
    /// # Errors
    ///
    /// Returns an error if the object cannot be read. Both implementations in this file
    /// report a missing object via `Error::not_found`.
    fn get_object(
        &self,
        bucket: &str,
        key: &str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send + '_>>;
    /// Stores `body` as the object `key` in `bucket`, replacing any existing object.
    ///
    /// # Errors
    ///
    /// Returns an error if the write fails (the in-memory implementation never fails
    /// except by panicking on a poisoned lock).
    fn put_object(
        &self,
        bucket: &str,
        key: &str,
        body: Vec<u8>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + '_>>;
}
/// [`TransferStorage`] implementation backed by Amazon S3 through an `aws_sdk_s3::Client`.
///
/// Derives `Debug` like the other storage type in this file so it can appear in
/// diagnostics and in `Debug`-deriving containers.
#[derive(Clone, Debug)]
pub struct AwsTransferStorage {
    client: aws_sdk_s3::Client,
}

impl AwsTransferStorage {
    /// Wraps an already-configured S3 client.
    #[must_use]
    pub const fn new(client: aws_sdk_s3::Client) -> Self {
        Self { client }
    }
}
impl TransferStorage for AwsTransferStorage {
fn get_object(
&self,
bucket: &str,
key: &str,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send + '_>> {
let client = self.client.clone();
let bucket = bucket.to_string();
let key = key.to_string();
Box::pin(async move {
retry::with_retry_and_timeout(|| {
let client = client.clone();
let bucket = bucket.clone();
let key = key.clone();
async move {
let resp = client
.get_object()
.bucket(&bucket)
.key(&key)
.send()
.await
.map_err(|e| {
let status = e.raw_response().map(|r| r.status().as_u16());
let err = match status {
Some(404) => Error::not_found(e.to_string()),
Some(429) | Some(500..=599) => Error::api(e.to_string()),
_ => Error::api_permanent(e.to_string()),
};
err.with("bucket", &bucket).with("key", &key)
})?;
let body = resp.body;
let body_bytes = body.collect().await.map_err(|e| {
Error::api(e.to_string())
.with("bucket", &bucket)
.with("key", &key)
})?;
Ok(body_bytes.to_vec())
}
})
.await
})
}
fn put_object(
&self,
bucket: &str,
key: &str,
body: Vec<u8>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + '_>> {
let client = self.client.clone();
let bucket = bucket.to_string();
let key = key.to_string();
Box::pin(async move {
retry::with_retry_and_timeout(|| {
let client = client.clone();
let bucket = bucket.clone();
let key = key.clone();
let body = aws_sdk_s3::primitives::ByteStream::from(body.clone());
async move {
client
.put_object()
.bucket(&bucket)
.key(&key)
.body(body)
.send()
.await
.map_err(|e| {
let status = e.raw_response().map(|r| r.status().as_u16());
let err = match status {
Some(429) | Some(500..=599) => Error::api(e.to_string()),
_ => Error::api_permanent(e.to_string()),
};
err.with("bucket", &bucket).with("key", &key)
})?;
Ok(())
}
})
.await
})
}
}
/// Backing map for [`MemoryTransferStorage`], keyed by `(bucket, key)`.
///
/// Wrapped in `Arc<RwLock<..>>` so that every clone of the storage shares the same objects.
type MemoryStore =
    std::sync::Arc<std::sync::RwLock<std::collections::HashMap<(String, String), Vec<u8>>>>;

/// In-process [`TransferStorage`] that keeps objects in a shared `HashMap`.
///
/// Clones share one underlying map, so data written through one clone is visible
/// through all of them.
#[derive(Clone, Debug, Default)]
pub struct MemoryTransferStorage {
    store: MemoryStore,
}

impl MemoryTransferStorage {
    /// Creates an empty store; identical to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `json_bytes` under `(bucket, key)`, replacing any previous value.
    ///
    /// The bytes are stored verbatim; they are not parsed or validated as JSON.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    pub fn inject_listing_result(&self, bucket: &str, key: &str, json_bytes: Vec<u8>) {
        self.store
            .write()
            .unwrap()
            .insert((bucket.to_string(), key.to_string()), json_bytes);
    }

    /// Returns every `(bucket, key)` currently stored, in unspecified order.
    #[cfg(test)]
    #[must_use]
    pub fn test_keys(&self) -> Vec<(String, String)> {
        self.store.read().unwrap().keys().cloned().collect()
    }
}
impl TransferStorage for MemoryTransferStorage {
    /// Returns a copy of the bytes stored under `(bucket, key)`.
    ///
    /// Fails with `Error::not_found("object not found")` (tagged with `bucket` and `key`)
    /// when nothing is stored there. Panics if the internal lock has been poisoned.
    fn get_object(
        &self,
        bucket: &str,
        key: &str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send + '_>> {
        let store = std::sync::Arc::clone(&self.store);
        let lookup = (bucket.to_string(), key.to_string());
        Box::pin(async move {
            // The read guard is a temporary and is released at the end of this statement.
            let found = store.read().unwrap().get(&lookup).cloned();
            // Only the miss path needs the names for error context, so move the owned
            // lookup key into the error instead of cloning it up front on every call.
            found.ok_or_else(|| {
                let (bucket, key) = lookup;
                Error::not_found("object not found")
                    .with("bucket", bucket)
                    .with("key", key)
            })
        })
    }

    /// Stores `body` under `(bucket, key)`, overwriting any previous value.
    ///
    /// Always returns `Ok(())`; panics if the internal lock has been poisoned.
    fn put_object(
        &self,
        bucket: &str,
        key: &str,
        body: Vec<u8>,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + '_>> {
        let store = std::sync::Arc::clone(&self.store);
        let entry_key = (bucket.to_string(), key.to_string());
        Box::pin(async move {
            store.write().unwrap().insert(entry_key, body);
            Ok(())
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ErrorKind;

    #[test]
    fn split_s3_path_leading_slashes() {
        assert_eq!(split_s3_path("/bucket/prefix/"), ("bucket", "prefix/"));
    }

    #[test]
    fn split_s3_path_bucket_only() {
        assert_eq!(split_s3_path("my-bucket"), ("my-bucket", ""));
    }

    #[test]
    fn split_s3_path_bucket_and_prefix() {
        assert_eq!(
            split_s3_path("my-bucket/staging/listings/"),
            ("my-bucket", "staging/listings/")
        );
    }

    #[test]
    fn split_s3_path_empty_after_trim() {
        assert_eq!(split_s3_path("/"), ("", ""));
    }

    #[tokio::test]
    async fn memory_storage_put_then_get() {
        let store = MemoryTransferStorage::new();
        store.put_object("b", "k", b"hello".to_vec()).await.unwrap();
        let fetched = store.get_object("b", "k").await.unwrap();
        assert_eq!(fetched, b"hello".to_vec());
    }

    #[tokio::test]
    async fn memory_storage_get_missing_returns_error() {
        let store = MemoryTransferStorage::new();
        let err = store
            .get_object("b", "missing")
            .await
            .expect_err("a key that was never written must not be found");
        assert_eq!(err.kind, ErrorKind::NotFound);
    }

    #[tokio::test]
    async fn memory_storage_inject_listing_result_then_get() {
        let store = MemoryTransferStorage::new();
        let listing: &[u8] = br#"{"files":[],"paths":[],"truncated":false}"#;
        store.inject_listing_result("bucket", "listings/out.json", listing.to_vec());
        let fetched = store
            .get_object("bucket", "listings/out.json")
            .await
            .unwrap();
        assert_eq!(fetched, listing);
    }

    #[tokio::test]
    async fn memory_storage_put_overwrites() {
        let store = MemoryTransferStorage::new();
        // Write the same key twice; the last write must win.
        for payload in [&b"first"[..], &b"second"[..]] {
            store.put_object("b", "k", payload.to_vec()).await.unwrap();
        }
        let fetched = store.get_object("b", "k").await.unwrap();
        assert_eq!(fetched, b"second".to_vec());
    }
}