use crate::traits::BlockStore;
use async_trait::async_trait;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::{Delete, ObjectIdentifier};
use aws_sdk_s3::Client as S3Client;
use futures::future::join_all;
use ipfrs_core::{Block, Cid, Error, Result};
use std::sync::Arc;
use tokio::sync::Semaphore;
/// Configuration for an [`S3BlockStore`]: target bucket, key layout, client overrides and
/// upload tuning knobs. Build it with `S3Config::new` and the `with_*` setters.
#[derive(Debug, Clone)]
pub struct S3Config {
/// Name of the bucket that holds the block objects.
pub bucket: String,
/// Optional string prepended verbatim to each CID string to form the object key.
pub prefix: Option<String>,
/// Region override; when `None`, the region resolved by the default AWS config loader is used.
pub region: Option<String>,
/// Custom endpoint URL (e.g. a local S3-compatible server); when set, the client also uses
/// path-style addressing.
pub endpoint: Option<String>,
/// Payloads of at least this many bytes are uploaded via multipart upload; the value is also
/// a lower bound on the multipart part size. `S3Config::new` sets 5 MiB.
pub multipart_threshold: usize,
/// Upper bound on concurrently running S3 requests in batch and multipart operations.
/// `S3Config::new` sets 10.
pub max_concurrent: usize,
}
impl S3Config {
    /// Creates a configuration for `bucket` with no key prefix, no region or endpoint
    /// override, a 5 MiB multipart threshold and at most 10 concurrent requests.
    pub fn new(bucket: String) -> Self {
        Self {
            bucket,
            prefix: None,
            region: None,
            endpoint: None,
            multipart_threshold: 5 * 1024 * 1024,
            max_concurrent: 10,
        }
    }

    /// Sets the string prepended verbatim to every object key.
    ///
    /// No separator is inserted, so include a trailing `/` if you want a "directory" layout
    /// (e.g. `"ipfrs/blocks/"`).
    #[must_use]
    pub fn with_prefix(mut self, prefix: String) -> Self {
        self.prefix = Some(prefix);
        self
    }

    /// Overrides the AWS region used by the client.
    #[must_use]
    pub fn with_region(mut self, region: String) -> Self {
        self.region = Some(region);
        self
    }

    /// Points the client at a custom endpoint URL (path-style addressing is enabled for it).
    #[must_use]
    pub fn with_endpoint(mut self, endpoint: String) -> Self {
        self.endpoint = Some(endpoint);
        self
    }

    /// Sets the payload size (in bytes) from which multipart upload is used.
    ///
    /// The threshold is also used as the minimum multipart part size.
    #[must_use]
    pub fn with_multipart_threshold(mut self, threshold: usize) -> Self {
        self.multipart_threshold = threshold;
        self
    }

    /// Sets the maximum number of S3 requests run concurrently by batch and multipart
    /// operations.
    ///
    /// A value of `0` is raised to `1`. With zero permits, every batched operation would wait
    /// forever for a permit that can never be released.
    #[must_use]
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max.max(1);
        self
    }

    /// Returns the object key for `cid`: its string form, preceded by `prefix` when set.
    fn build_key(&self, cid: &Cid) -> String {
        match &self.prefix {
            Some(prefix) => format!("{prefix}{cid}"),
            None => cid.to_string(),
        }
    }
}
/// Block store that keeps every block as a single object in an S3 bucket, keyed by the
/// block's CID string (optionally preceded by the configured prefix).
#[derive(Clone)]
pub struct S3BlockStore {
// Shared behind an `Arc` so clones of the store (e.g. the per-task clones made by the
// batch operations) reuse one client.
client: Arc<S3Client>,
config: S3Config,
}
impl S3BlockStore {
pub async fn new(config: S3Config) -> Result<Self> {
let aws_config = if let Some(endpoint) = &config.endpoint {
let mut builder = aws_config::defaults(aws_config::BehaviorVersion::latest());
if let Some(region) = &config.region {
builder = builder.region(aws_sdk_s3::config::Region::new(region.clone()));
}
let aws_config = builder.load().await;
aws_sdk_s3::config::Builder::from(&aws_config)
.endpoint_url(endpoint)
.force_path_style(true) .build()
} else {
let mut builder = aws_config::defaults(aws_config::BehaviorVersion::latest());
if let Some(region) = &config.region {
builder = builder.region(aws_sdk_s3::config::Region::new(region.clone()));
}
let aws_config = builder.load().await;
aws_sdk_s3::config::Builder::from(&aws_config).build()
};
let client = S3Client::from_conf(aws_config);
Ok(Self {
client: Arc::new(client),
config,
})
}
pub fn client(&self) -> &Arc<S3Client> {
&self.client
}
pub fn config(&self) -> &S3Config {
&self.config
}
async fn put_simple(&self, key: &str, data: &[u8]) -> Result<()> {
self.client
.put_object()
.bucket(&self.config.bucket)
.key(key)
.body(ByteStream::from(data.to_vec()))
.send()
.await
.map_err(|e| Error::Storage(format!("Failed to put block to S3: {e}")))?;
Ok(())
}
async fn put_multipart(&self, key: &str, data: &[u8]) -> Result<()> {
let base_part_size = 5 * 1024 * 1024; let data_size = data.len();
let part_size = if data_size > 1024 * 1024 * 1024 {
std::cmp::max(10 * 1024 * 1024, self.config.multipart_threshold)
} else if data_size > 100 * 1024 * 1024 {
std::cmp::max(8 * 1024 * 1024, self.config.multipart_threshold)
} else {
std::cmp::max(base_part_size, self.config.multipart_threshold)
};
let multipart = self
.client
.create_multipart_upload()
.bucket(&self.config.bucket)
.key(key)
.send()
.await
.map_err(|e| Error::Storage(format!("Failed to initiate multipart upload: {e}")))?;
let upload_id = multipart
.upload_id()
.ok_or_else(|| Error::Storage("No upload ID returned".to_string()))?;
let chunks: Vec<_> = data.chunks(part_size).collect();
let semaphore = Arc::new(Semaphore::new(self.config.max_concurrent));
let mut futures = Vec::new();
for (part_number, chunk) in chunks.iter().enumerate() {
let part_num = (part_number + 1) as i32;
let data_chunk = chunk.to_vec();
let client = self.client.clone();
let bucket = self.config.bucket.clone();
let key = key.to_string();
let upload_id = upload_id.to_string();
let sem = semaphore.clone();
let future = async move {
let _permit = sem.acquire().await.unwrap();
let mut attempts = 0;
let max_attempts = 3;
let mut last_error = None;
while attempts < max_attempts {
match client
.upload_part()
.bucket(&bucket)
.key(&key)
.upload_id(&upload_id)
.part_number(part_num)
.body(ByteStream::from(data_chunk.clone()))
.send()
.await
{
Ok(output) => {
return Ok((part_num, output.e_tag().unwrap_or_default().to_string()));
}
Err(e) => {
attempts += 1;
last_error = Some(e);
if attempts < max_attempts {
tokio::time::sleep(tokio::time::Duration::from_millis(
100 * (1 << (attempts - 1)),
))
.await;
}
}
}
}
Err(last_error.unwrap())
};
futures.push(future);
}
let results = join_all(futures).await;
let mut upload_parts = Vec::new();
for result in results {
match result {
Ok((part_number, etag)) => {
upload_parts.push((
part_number,
aws_sdk_s3::types::CompletedPart::builder()
.part_number(part_number)
.e_tag(etag)
.build(),
));
}
Err(e) => {
let _ = self
.client
.abort_multipart_upload()
.bucket(&self.config.bucket)
.key(key)
.upload_id(upload_id)
.send()
.await;
return Err(Error::Storage(format!(
"Failed to upload part after retries: {e}"
)));
}
}
}
upload_parts.sort_by_key(|(part_num, _)| *part_num);
let sorted_parts: Vec<_> = upload_parts.into_iter().map(|(_, part)| part).collect();
let completed_upload = aws_sdk_s3::types::CompletedMultipartUpload::builder()
.set_parts(Some(sorted_parts))
.build();
self.client
.complete_multipart_upload()
.bucket(&self.config.bucket)
.key(key)
.upload_id(upload_id)
.multipart_upload(completed_upload)
.send()
.await
.map_err(|e| Error::Storage(format!("Failed to complete multipart upload: {e}")))?;
Ok(())
}
}
#[async_trait]
impl BlockStore for S3BlockStore {
async fn put(&self, block: &Block) -> Result<()> {
let key = self.config.build_key(block.cid());
let data = block.data();
if data.len() >= self.config.multipart_threshold {
self.put_multipart(&key, data).await
} else {
self.put_simple(&key, data).await
}
}
async fn get(&self, cid: &Cid) -> Result<Option<Block>> {
let key = self.config.build_key(cid);
match self
.client
.get_object()
.bucket(&self.config.bucket)
.key(&key)
.send()
.await
{
Ok(output) => {
let data = output
.body
.collect()
.await
.map_err(|e| Error::Storage(format!("Failed to read S3 object body: {e}")))?
.into_bytes();
Ok(Some(Block::from_parts(*cid, data)))
}
Err(e) => {
let error_str = e.to_string();
if error_str.contains("NoSuchKey") || error_str.contains("404") {
Ok(None)
} else {
Err(Error::Storage(format!("Failed to get block from S3: {e}")))
}
}
}
}
async fn has(&self, cid: &Cid) -> Result<bool> {
let key = self.config.build_key(cid);
match self
.client
.head_object()
.bucket(&self.config.bucket)
.key(&key)
.send()
.await
{
Ok(_) => Ok(true),
Err(e) => {
let error_str = e.to_string();
if error_str.contains("NotFound") || error_str.contains("404") {
Ok(false)
} else {
Err(Error::Storage(format!(
"Failed to check block existence in S3: {e}"
)))
}
}
}
}
async fn delete(&self, cid: &Cid) -> Result<()> {
let key = self.config.build_key(cid);
self.client
.delete_object()
.bucket(&self.config.bucket)
.key(&key)
.send()
.await
.map_err(|e| Error::Storage(format!("Failed to delete block from S3: {e}")))?;
Ok(())
}
fn len(&self) -> usize {
0
}
fn is_empty(&self) -> bool {
false
}
fn list_cids(&self) -> Result<Vec<Cid>> {
Ok(Vec::new())
}
async fn put_many(&self, blocks: &[Block]) -> Result<()> {
if blocks.is_empty() {
return Ok(());
}
let semaphore = Arc::new(Semaphore::new(self.config.max_concurrent));
let mut tasks = Vec::with_capacity(blocks.len());
for block in blocks {
let permit = semaphore
.clone()
.acquire_owned()
.await
.map_err(|e| Error::Storage(format!("Failed to acquire semaphore: {e}")))?;
let block = block.clone();
let store = self.clone();
tasks.push(tokio::spawn(async move {
let result = store.put(&block).await;
drop(permit); result
}));
}
let results = join_all(tasks).await;
for result in results {
match result {
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(e),
Err(e) => return Err(Error::Storage(format!("Task join error: {e}"))),
}
}
Ok(())
}
async fn get_many(&self, cids: &[Cid]) -> Result<Vec<Option<Block>>> {
if cids.is_empty() {
return Ok(Vec::new());
}
let semaphore = Arc::new(Semaphore::new(self.config.max_concurrent));
let mut tasks = Vec::with_capacity(cids.len());
for cid in cids {
let permit = semaphore
.clone()
.acquire_owned()
.await
.map_err(|e| Error::Storage(format!("Failed to acquire semaphore: {e}")))?;
let cid = *cid;
let store = self.clone();
tasks.push(tokio::spawn(async move {
let result = store.get(&cid).await;
drop(permit); result
}));
}
let results = join_all(tasks).await;
let mut blocks = Vec::with_capacity(cids.len());
for result in results {
match result {
Ok(Ok(block)) => blocks.push(block),
Ok(Err(e)) => return Err(e),
Err(e) => return Err(Error::Storage(format!("Task join error: {e}"))),
}
}
Ok(blocks)
}
async fn has_many(&self, cids: &[Cid]) -> Result<Vec<bool>> {
if cids.is_empty() {
return Ok(Vec::new());
}
let semaphore = Arc::new(Semaphore::new(self.config.max_concurrent));
let mut tasks = Vec::with_capacity(cids.len());
for cid in cids {
let permit = semaphore
.clone()
.acquire_owned()
.await
.map_err(|e| Error::Storage(format!("Failed to acquire semaphore: {e}")))?;
let cid = *cid;
let store = self.clone();
tasks.push(tokio::spawn(async move {
let result = store.has(&cid).await;
drop(permit); result
}));
}
let results = join_all(tasks).await;
let mut exists_vec = Vec::with_capacity(cids.len());
for result in results {
match result {
Ok(Ok(exists)) => exists_vec.push(exists),
Ok(Err(e)) => return Err(e),
Err(e) => return Err(Error::Storage(format!("Task join error: {e}"))),
}
}
Ok(exists_vec)
}
async fn delete_many(&self, cids: &[Cid]) -> Result<()> {
if cids.is_empty() {
return Ok(());
}
const BATCH_SIZE: usize = 1000;
for chunk in cids.chunks(BATCH_SIZE) {
let mut objects = Vec::with_capacity(chunk.len());
for cid in chunk {
let key = self.config.build_key(cid);
objects.push(ObjectIdentifier::builder().key(key).build().map_err(|e| {
Error::Storage(format!("Failed to build object identifier: {e}"))
})?);
}
let delete = Delete::builder()
.set_objects(Some(objects))
.build()
.map_err(|e| Error::Storage(format!("Failed to build delete request: {e}")))?;
self.client
.delete_objects()
.bucket(&self.config.bucket)
.delete(delete)
.send()
.await
.map_err(|e| Error::Storage(format!("Failed to delete objects: {e}")))?;
}
Ok(())
}
async fn flush(&self) -> Result<()> {
Ok(())
}
}
impl S3BlockStore {
pub async fn list_cids_async(&self) -> Result<Vec<Cid>> {
let mut cids = Vec::new();
let mut continuation_token: Option<String> = None;
loop {
let mut request = self.client.list_objects_v2().bucket(&self.config.bucket);
if let Some(prefix) = &self.config.prefix {
request = request.prefix(prefix);
}
if let Some(token) = continuation_token {
request = request.continuation_token(token);
}
let output = request
.send()
.await
.map_err(|e| Error::Storage(format!("Failed to list S3 objects: {e}")))?;
if let Some(contents) = output.contents {
for object in contents {
if let Some(key) = object.key {
let cid_str = if let Some(prefix) = &self.config.prefix {
key.strip_prefix(prefix).unwrap_or(&key)
} else {
&key
};
if let Ok(cid) = cid_str.parse::<Cid>() {
cids.push(cid);
}
}
}
}
if output.is_truncated == Some(true) {
continuation_token = output.next_continuation_token;
} else {
break;
}
}
Ok(cids)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_s3_config_build_key() {
        let config = S3Config::new("test-bucket".to_string());
        // "QmTest123" is not a valid CID; fall back to the CID of a real block.
        let cid = "QmTest123".parse::<Cid>().unwrap_or_else(|_| {
            use bytes::Bytes;
            *Block::new(Bytes::from("test")).unwrap().cid()
        });

        // Without a prefix the key is exactly the CID string.
        let key = config.build_key(&cid);
        assert_eq!(key, cid.to_string());

        // The prefix is prepended verbatim.
        let config = config.with_prefix("ipfrs/blocks/".to_string());
        let key = config.build_key(&cid);
        assert_eq!(key, format!("ipfrs/blocks/{cid}"));
    }

    #[test]
    fn test_s3_config_defaults() {
        let config = S3Config::new("test-bucket".to_string());
        assert_eq!(config.bucket, "test-bucket");
        assert_eq!(config.prefix, None);
        assert_eq!(config.region, None);
        assert_eq!(config.endpoint, None);
        assert_eq!(config.multipart_threshold, 5 * 1024 * 1024);
        assert_eq!(config.max_concurrent, 10);
    }

    #[test]
    fn test_s3_config_builder() {
        let config = S3Config::new("test-bucket".to_string())
            .with_prefix("ipfrs/".to_string())
            .with_region("us-west-2".to_string())
            .with_endpoint("http://localhost:9000".to_string())
            .with_multipart_threshold(10 * 1024 * 1024)
            .with_max_concurrent(4);
        assert_eq!(config.bucket, "test-bucket");
        assert_eq!(config.prefix, Some("ipfrs/".to_string()));
        assert_eq!(config.region, Some("us-west-2".to_string()));
        assert_eq!(config.endpoint, Some("http://localhost:9000".to_string()));
        assert_eq!(config.multipart_threshold, 10 * 1024 * 1024);
        assert_eq!(config.max_concurrent, 4);
    }
}