use async_trait::async_trait;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::Client as S3Client;
use hashtree_core::store::{Store, StoreError};
use hashtree_core::types::{to_hex, Hash};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
/// Configuration for connecting an [`S3Store`] to an S3-compatible backend.
#[derive(Debug, Clone)]
pub struct S3Config {
    /// Target bucket name.
    pub bucket: String,
    /// Optional key prefix prepended to every object key (e.g. `"blobs/"`).
    /// `None` behaves like an empty prefix.
    pub prefix: Option<String>,
    /// AWS region; falls back to environment/default resolution when `None`.
    pub region: Option<String>,
    /// Custom endpoint URL (e.g. MinIO/LocalStack). When set, path-style
    /// addressing is forced.
    pub endpoint: Option<String>,
}
/// Messages sent to the background sync task that mirrors local
/// mutations to S3 asynchronously.
enum SyncMessage {
    /// Upload a blob's bytes under its hash-derived key.
    Upload { hash: Hash, data: Vec<u8> },
    /// Delete the object for the given hash.
    Delete { hash: Hash },
    /// Stop the sync task's receive loop.
    Shutdown,
}
/// A write-through store: reads and writes hit the local store first,
/// while uploads/deletes are mirrored to S3 in the background via an
/// unbounded channel to a spawned sync task.
pub struct S3Store<L: Store> {
    // Authoritative fast-path store; S3 is the durable mirror / fallback.
    local: Arc<L>,
    s3_client: S3Client,
    bucket: String,
    // Always a concrete string; empty when no prefix was configured.
    prefix: String,
    // Sender half feeding the background sync task spawned in `new`.
    sync_tx: mpsc::UnboundedSender<SyncMessage>,
}
impl<L: Store + 'static> S3Store<L> {
pub async fn new(local: Arc<L>, config: S3Config) -> Result<Self, S3StoreError> {
let mut aws_config_loader = aws_config::from_env();
if let Some(ref region) = config.region {
aws_config_loader =
aws_config_loader.region(aws_sdk_s3::config::Region::new(region.clone()));
}
let aws_config = aws_config_loader.load().await;
let mut s3_config_builder = aws_sdk_s3::config::Builder::from(&aws_config);
if let Some(ref endpoint) = config.endpoint {
s3_config_builder = s3_config_builder
.endpoint_url(endpoint)
.force_path_style(true); }
let s3_client = S3Client::from_conf(s3_config_builder.build());
let prefix = config.prefix.unwrap_or_default();
let bucket = config.bucket.clone();
let (sync_tx, sync_rx) = mpsc::unbounded_channel();
let sync_client = s3_client.clone();
let sync_bucket = bucket.clone();
let sync_prefix = prefix.clone();
tokio::spawn(async move {
Self::sync_task(sync_rx, sync_client, sync_bucket, sync_prefix).await;
});
info!(
"S3Store initialized with bucket: {}, prefix: {}",
bucket, prefix
);
Ok(Self {
local,
s3_client,
bucket,
prefix,
sync_tx,
})
}
async fn sync_task(
mut rx: mpsc::UnboundedReceiver<SyncMessage>,
client: S3Client,
bucket: String,
prefix: String,
) {
info!("S3 sync task started");
while let Some(msg) = rx.recv().await {
match msg {
SyncMessage::Upload { hash, data } => {
let key = format!("{}{}", prefix, to_hex(&hash));
debug!(
"S3 uploading {} ({} bytes)",
&key[..16.min(key.len())],
data.len()
);
match client
.put_object()
.bucket(&bucket)
.key(&key)
.body(ByteStream::from(data))
.send()
.await
{
Ok(_) => {
debug!("S3 upload complete: {}", &key[..16.min(key.len())]);
}
Err(e) => {
error!("S3 upload failed for {}: {}", &key[..16.min(key.len())], e);
}
}
}
SyncMessage::Delete { hash } => {
let key = format!("{}{}", prefix, to_hex(&hash));
debug!("S3 deleting {}", &key[..16.min(key.len())]);
match client
.delete_object()
.bucket(&bucket)
.key(&key)
.send()
.await
{
Ok(_) => {
debug!("S3 delete complete: {}", &key[..16.min(key.len())]);
}
Err(e) => {
error!("S3 delete failed for {}: {}", &key[..16.min(key.len())], e);
}
}
}
SyncMessage::Shutdown => {
info!("S3 sync task shutting down");
break;
}
}
}
}
fn s3_key(&self, hash: &Hash) -> String {
format!("{}{}", self.prefix, to_hex(hash))
}
async fn fetch_from_s3(&self, hash: &Hash) -> Result<Option<Vec<u8>>, S3StoreError> {
let key = self.s3_key(hash);
match self
.s3_client
.get_object()
.bucket(&self.bucket)
.key(&key)
.send()
.await
{
Ok(output) => {
let data = output
.body
.collect()
.await
.map_err(|e| S3StoreError::S3(format!("Failed to read body: {}", e)))?;
Ok(Some(data.into_bytes().to_vec()))
}
Err(e) => {
let service_err = e.into_service_error();
if service_err.is_no_such_key() {
Ok(None)
} else {
Err(S3StoreError::S3(format!("S3 get failed: {}", service_err)))
}
}
}
}
async fn exists_in_s3(&self, hash: &Hash) -> Result<bool, S3StoreError> {
let key = self.s3_key(hash);
match self
.s3_client
.head_object()
.bucket(&self.bucket)
.key(&key)
.send()
.await
{
Ok(_) => Ok(true),
Err(e) => {
let service_err = e.into_service_error();
if service_err.is_not_found() {
Ok(false)
} else {
Err(S3StoreError::S3(format!("S3 head failed: {}", service_err)))
}
}
}
}
fn queue_upload(&self, hash: Hash, data: Vec<u8>) {
if let Err(e) = self.sync_tx.send(SyncMessage::Upload { hash, data }) {
warn!("Failed to queue S3 upload: {}", e);
}
}
fn queue_delete(&self, hash: Hash) {
if let Err(e) = self.sync_tx.send(SyncMessage::Delete { hash }) {
warn!("Failed to queue S3 delete: {}", e);
}
}
pub fn shutdown(&self) {
let _ = self.sync_tx.send(SyncMessage::Shutdown);
}
}
impl<L: Store> Drop for S3Store<L> {
    /// Signals the background sync task to stop when the store is dropped.
    fn drop(&mut self) {
        // Best-effort: the send fails only if the task already exited,
        // which is fine to ignore during teardown. Note: any queued
        // uploads still in the channel are processed before Shutdown,
        // but in-flight transfers are not awaited here.
        let _ = self.sync_tx.send(SyncMessage::Shutdown);
    }
}
#[async_trait]
impl<L: Store + 'static> Store for S3Store<L> {
    /// Writes to the local store; when the blob is new, mirrors it to S3
    /// asynchronously. Returns whether the blob was newly stored.
    async fn put(&self, hash: Hash, data: Vec<u8>) -> Result<bool, StoreError> {
        // Clone so the payload can both enter the local store and be queued.
        let newly_stored = self.local.put(hash, data.clone()).await?;
        if !newly_stored {
            return Ok(false);
        }
        self.queue_upload(hash, data);
        Ok(true)
    }

    /// Reads from the local store first; on a miss, falls back to S3 and
    /// warms the local cache. S3 failures degrade to a miss (`None`).
    async fn get(&self, hash: &Hash) -> Result<Option<Vec<u8>>, StoreError> {
        let cached = self.local.get(hash).await?;
        if cached.is_some() {
            return Ok(cached);
        }
        let fetched = match self.fetch_from_s3(hash).await {
            Ok(found) => found,
            Err(e) => {
                warn!("S3 fetch failed, returning None: {}", e);
                None
            }
        };
        if let Some(ref bytes) = fetched {
            // Populate the local cache; a failure here is non-fatal.
            let _ = self.local.put(*hash, bytes.clone()).await;
        }
        Ok(fetched)
    }

    /// Checks the local store, then S3. S3 failures degrade to `false`.
    async fn has(&self, hash: &Hash) -> Result<bool, StoreError> {
        if self.local.has(hash).await? {
            return Ok(true);
        }
        let remote = self.exists_in_s3(hash).await.unwrap_or_else(|e| {
            warn!("S3 exists check failed, returning false: {}", e);
            false
        });
        Ok(remote)
    }

    /// Deletes locally and queues an S3 delete regardless of whether the
    /// blob was present locally (S3 may hold it even if local does not).
    async fn delete(&self, hash: &Hash) -> Result<bool, StoreError> {
        let removed_locally = self.local.delete(hash).await?;
        self.queue_delete(*hash);
        Ok(removed_locally)
    }
}
/// Errors produced by [`S3Store`] operations against the S3 backend.
#[derive(Debug, thiserror::Error)]
pub enum S3StoreError {
    /// An S3 request failed (stringified SDK/service error).
    #[error("S3 error: {0}")]
    S3(String),
    /// Invalid or unusable configuration.
    #[error("Configuration error: {0}")]
    Config(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use hashtree_core::hash::sha256;

    /// An object key is `<prefix><64 hex chars of SHA-256 digest>`.
    #[test]
    fn test_s3_key_generation() {
        let namespace = "blobs/";
        let digest = sha256(b"test");
        let object_key = format!("{}{}", namespace, to_hex(&digest));
        assert!(object_key.starts_with("blobs/"));
        // 6-byte prefix plus 64 hex characters for the digest.
        assert_eq!(object_key.len(), 6 + 64);
    }

    /// Construction of `S3Config` preserves the provided fields.
    #[test]
    fn test_s3_config() {
        let cfg = S3Config {
            bucket: "test-bucket".to_string(),
            prefix: Some("data/".to_string()),
            region: Some("us-east-1".to_string()),
            endpoint: None,
        };
        assert_eq!(cfg.bucket, "test-bucket");
        assert_eq!(cfg.prefix.as_deref(), Some("data/"));
    }
}