use std::collections::HashMap;
use bytes::Bytes;
use md5::{Digest, Md5};
use sqlx::PgPool;
use crate::backend::StorageBackend;
use crate::error::Post3Error;
use crate::models::{
BucketInfo, BucketRow, CompleteMultipartUploadResult, CreateMultipartUploadResult,
GetObjectResult, HeadObjectResult, ListMultipartUploadsResult, ListObjectsResult,
ListPartsResult, MultipartUploadInfo, MultipartUploadRow, ObjectInfo, ObjectMeta,
PutObjectResult, UploadPartResult,
};
use crate::repositories::blocks::BlocksRepository;
use crate::repositories::buckets::BucketsRepository;
use crate::repositories::metadata::MetadataRepository;
use crate::repositories::multipart_metadata::MultipartMetadataRepository;
use crate::repositories::multipart_uploads::MultipartUploadsRepository;
use crate::repositories::objects::ObjectsRepository;
use crate::repositories::upload_parts::UploadPartsRepository;
/// Default size, in bytes, of the blocks an object body is split into (1 MiB).
///
/// [`Store::new`] starts with this value; [`Store::with_block_size`] overrides it.
pub const DEFAULT_BLOCK_SIZE: usize = 1024 * 1024;
/// Storage backend handle: a Postgres connection pool plus the block size
/// used when an object body is split into stored blocks.
#[derive(Clone)]
pub struct Store {
// Connection pool that every repository call in this file runs against.
db: PgPool,
// Maximum number of bytes written per block when a body is chunked.
block_size: usize,
}
/// Alias for [`Store`] under a name that mentions its Postgres backing.
pub type PostgresBackend = Store;
impl Store {
/// Builds a store on top of `db` that splits object bodies into blocks of
/// [`DEFAULT_BLOCK_SIZE`] bytes.
pub fn new(db: PgPool) -> Self {
    let block_size = DEFAULT_BLOCK_SIZE;
    Self { db, block_size }
}
/// Returns the store reconfigured to split object bodies into blocks of
/// `block_size` bytes.
///
/// A `block_size` of zero is raised to one byte. Bodies are split with the
/// slice `chunks` method, which panics when asked for zero-sized chunks, so
/// storing zero here would turn every later write into a panic.
pub fn with_block_size(mut self, block_size: usize) -> Self {
    self.block_size = block_size.max(1);
    self
}
pub fn pool(&self) -> &PgPool {
&self.db
}
/// Looks up the bucket called `name`.
///
/// # Errors
/// `Post3Error::BucketNotFound` when the lookup yields no row; repository
/// failures are propagated.
async fn require_bucket(&self, name: &str) -> Result<BucketRow, Post3Error> {
    let buckets = BucketsRepository::new(&self.db);
    match buckets.get_by_name(name).await? {
        Some(row) => Ok(row),
        None => Err(Post3Error::BucketNotFound(name.to_string())),
    }
}
/// Fetches the multipart upload identified by `upload_id` and checks that it
/// belongs to `expected_bucket_id` and `expected_key`.
///
/// # Errors
/// `Post3Error::UploadNotFound` both when no such upload exists and when it
/// exists but targets a different bucket or key.
async fn require_upload(
    &self,
    upload_id: &str,
    expected_bucket_id: uuid::Uuid,
    expected_key: &str,
) -> Result<MultipartUploadRow, Post3Error> {
    let found = MultipartUploadsRepository::get_by_upload_id(&self.db, upload_id).await?;
    let Some(upload) = found else {
        return Err(Post3Error::UploadNotFound(upload_id.to_string()));
    };
    let same_target = upload.bucket_id == expected_bucket_id && upload.key == expected_key;
    if same_target {
        Ok(upload)
    } else {
        Err(Post3Error::UploadNotFound(upload_id.to_string()))
    }
}
}
impl StorageBackend for Store {
/// Creates the bucket `name` and reports its name and creation time.
async fn create_bucket(&self, name: &str) -> Result<BucketInfo, Post3Error> {
    let buckets = BucketsRepository::new(&self.db);
    let created = buckets.create(name).await?;
    let info = BucketInfo {
        name: created.name,
        created_at: created.created_at,
    };
    Ok(info)
}
/// Returns the name and creation time of bucket `name`, or `None` when the
/// lookup finds no such bucket.
async fn head_bucket(&self, name: &str) -> Result<Option<BucketInfo>, Post3Error> {
    let found = BucketsRepository::new(&self.db).get_by_name(name).await?;
    match found {
        Some(row) => Ok(Some(BucketInfo {
            name: row.name,
            created_at: row.created_at,
        })),
        None => Ok(None),
    }
}
/// Deletes the bucket `name`, returning the repository's result unchanged.
async fn delete_bucket(&self, name: &str) -> Result<(), Post3Error> {
    let buckets = BucketsRepository::new(&self.db);
    buckets.delete(name).await
}
/// Lists every bucket the repository returns, in the repository's order, as
/// name / creation-time pairs.
async fn list_buckets(&self) -> Result<Vec<BucketInfo>, Post3Error> {
    let rows = BucketsRepository::new(&self.db).list().await?;
    let mut buckets = Vec::new();
    for row in rows {
        buckets.push(BucketInfo {
            name: row.name,
            created_at: row.created_at,
        });
    }
    Ok(buckets)
}
/// Stores `body` as the object `key` inside `bucket`.
///
/// The object row, its data blocks (slices of at most `block_size` bytes) and
/// any user metadata are written inside one transaction. The returned ETag is
/// the quoted hex MD5 of the body. A missing `content_type` is recorded as
/// `application/octet-stream`.
///
/// # Errors
/// `Post3Error::BucketNotFound` if the bucket is missing; repository and
/// transaction failures are propagated.
async fn put_object(
    &self,
    bucket: &str,
    key: &str,
    content_type: Option<&str>,
    metadata: HashMap<String, String>,
    body: Bytes,
) -> Result<PutObjectResult, Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;
    let content_type = match content_type {
        Some(explicit) => explicit,
        None => "application/octet-stream",
    };

    let size = body.len() as i64;
    let digest = {
        let mut md5 = Md5::new();
        md5.update(&body);
        md5.finalize()
    };
    let etag = format!("\"{}\"", hex::encode(digest));

    let mut tx = self.db.begin().await?;
    let object_row = ObjectsRepository::insert_in_tx(
        &mut tx,
        bucket_row.id,
        key,
        size,
        &etag,
        content_type,
    )
    .await?;

    // Blocks are numbered from zero in body order.
    let numbered_blocks = body.chunks(self.block_size).enumerate();
    for (position, data) in numbered_blocks {
        BlocksRepository::insert_in_tx(&mut tx, object_row.id, position as i32, data).await?;
    }

    // Skip the metadata insert entirely when there is nothing to store.
    let has_metadata = !metadata.is_empty();
    if has_metadata {
        MetadataRepository::insert_batch_in_tx(&mut tx, object_row.id, &metadata).await?;
    }

    tx.commit().await?;
    Ok(PutObjectResult { etag, size })
}
/// Reads the object `key` from `bucket`: its attributes, user metadata and
/// the full body, rebuilt by concatenating the stored blocks in the order the
/// repository returns them.
///
/// # Errors
/// `Post3Error::BucketNotFound` or `Post3Error::ObjectNotFound` when the
/// bucket or object is missing; repository failures are propagated.
async fn get_object(
    &self,
    bucket: &str,
    key: &str,
) -> Result<GetObjectResult, Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;
    let lookup = ObjectsRepository::new(&self.db)
        .get(bucket_row.id, key)
        .await?;
    let Some(object) = lookup else {
        return Err(Post3Error::ObjectNotFound {
            bucket: bucket.to_string(),
            key: key.to_string(),
        });
    };

    let blocks = BlocksRepository::get_all(&self.db, object.id).await?;
    // The recorded size is only used to pre-size the buffer.
    let mut body = Vec::with_capacity(object.size as usize);
    blocks
        .into_iter()
        .for_each(|block| body.extend_from_slice(&block.data));

    let user_metadata = MetadataRepository::get_all(&self.db, object.id).await?;

    let metadata = ObjectMeta {
        key: object.key,
        size: object.size,
        etag: object.etag,
        content_type: object.content_type,
        last_modified: object.created_at,
    };
    Ok(GetObjectResult {
        metadata,
        user_metadata,
        body: Bytes::from(body),
    })
}
/// Returns the attributes and user metadata of object `key` in `bucket`
/// without reading its body, or `None` when the object does not exist.
///
/// # Errors
/// `Post3Error::BucketNotFound` if the bucket is missing; repository failures
/// are propagated.
async fn head_object(
    &self,
    bucket: &str,
    key: &str,
) -> Result<Option<HeadObjectResult>, Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;
    let lookup = ObjectsRepository::new(&self.db)
        .get(bucket_row.id, key)
        .await?;
    let Some(found) = lookup else {
        return Ok(None);
    };

    let user_metadata = MetadataRepository::get_all(&self.db, found.id).await?;
    let object = ObjectMeta {
        key: found.key,
        size: found.size,
        etag: found.etag,
        content_type: found.content_type,
        last_modified: found.created_at,
    };
    Ok(Some(HeadObjectResult {
        object,
        user_metadata,
    }))
}
/// Deletes object `key` from `bucket`; whatever the repository's delete call
/// returns on success is discarded.
///
/// # Errors
/// `Post3Error::BucketNotFound` if the bucket is missing; repository failures
/// are propagated.
async fn delete_object(
    &self,
    bucket: &str,
    key: &str,
) -> Result<(), Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;
    let objects = ObjectsRepository::new(&self.db);
    objects.delete(bucket_row.id, key).await?;
    Ok(())
}
async fn list_objects_v2(
&self,
bucket: &str,
prefix: Option<&str>,
continuation_token: Option<&str>,
max_keys: Option<i64>,
delimiter: Option<&str>,
) -> Result<ListObjectsResult, Post3Error> {
let bucket_row = self.require_bucket(bucket).await?;
let max_keys = max_keys.unwrap_or(1000);
if max_keys == 0 {
return Ok(ListObjectsResult {
objects: Vec::new(),
is_truncated: false,
next_continuation_token: None,
prefix: prefix.map(|s| s.to_string()),
delimiter: delimiter.map(|s| s.to_string()),
common_prefixes: Vec::new(),
key_count: 0,
});
}
let fetch_limit = if delimiter.is_some() {
(max_keys + 1) * 10
} else {
max_keys + 1
};
let rows = ObjectsRepository::new(&self.db)
.list(bucket_row.id, prefix, continuation_token, fetch_limit)
.await?;
let all_objects: Vec<ObjectInfo> = rows
.into_iter()
.map(|o| ObjectInfo {
key: o.key,
size: o.size,
etag: o.etag,
last_modified: o.created_at,
})
.collect();
let prefix_str = prefix.unwrap_or("");
if let Some(delim) = delimiter {
let mut seen_prefixes = std::collections::BTreeSet::new();
let mut direct_objects = Vec::new();
for obj in &all_objects {
let after_prefix = &obj.key[prefix_str.len()..];
if let Some(pos) = after_prefix.find(delim) {
let cp = format!("{}{}", prefix_str, &after_prefix[..pos + delim.len()]);
seen_prefixes.insert(cp);
} else {
direct_objects.push(obj.clone());
}
}
let all_prefixes: Vec<String> = if let Some(token) = continuation_token {
seen_prefixes
.into_iter()
.filter(|cp| cp.as_str() > token)
.collect()
} else {
seen_prefixes.into_iter().collect()
};
let mut result_objects = Vec::new();
let mut result_prefixes = Vec::new();
let mut oi = 0usize;
let mut pi = 0usize;
let mut count = 0i64;
let mut last_key: Option<String> = None;
while count < max_keys && (oi < direct_objects.len() || pi < all_prefixes.len()) {
let take_object = match (direct_objects.get(oi), all_prefixes.get(pi)) {
(Some(obj), Some(pfx)) => obj.key.as_str() < pfx.as_str(),
(Some(_), None) => true,
(None, Some(_)) => false,
(None, None) => break,
};
if take_object {
last_key = Some(direct_objects[oi].key.clone());
result_objects.push(direct_objects[oi].clone());
oi += 1;
} else {
last_key = Some(all_prefixes[pi].clone());
result_prefixes.push(all_prefixes[pi].clone());
pi += 1;
}
count += 1;
}
let is_truncated = oi < direct_objects.len() || pi < all_prefixes.len();
let next_token = if is_truncated { last_key } else { None };
let key_count = result_objects.len() + result_prefixes.len();
Ok(ListObjectsResult {
objects: result_objects,
is_truncated,
next_continuation_token: next_token,
prefix: prefix.map(|s| s.to_string()),
delimiter: Some(delim.to_string()),
common_prefixes: result_prefixes,
key_count,
})
} else {
let is_truncated = all_objects.len() as i64 > max_keys;
let items: Vec<_> = all_objects.into_iter().take(max_keys as usize).collect();
let next_token = if is_truncated {
items.last().map(|o| o.key.clone())
} else {
None
};
let key_count = items.len();
Ok(ListObjectsResult {
objects: items,
is_truncated,
next_continuation_token: next_token,
prefix: prefix.map(|s| s.to_string()),
delimiter: None,
common_prefixes: Vec::new(),
key_count,
})
}
}
/// Starts a multipart upload for `key` in `bucket` and returns its freshly
/// generated upload id (a random UUID rendered as a string).
///
/// The upload row and any user metadata are written in one transaction. A
/// missing `content_type` is recorded as `application/octet-stream`.
///
/// # Errors
/// `Post3Error::BucketNotFound` if the bucket is missing; repository and
/// transaction failures are propagated.
async fn create_multipart_upload(
    &self,
    bucket: &str,
    key: &str,
    content_type: Option<&str>,
    metadata: HashMap<String, String>,
) -> Result<CreateMultipartUploadResult, Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;
    let content_type = match content_type {
        Some(explicit) => explicit,
        None => "application/octet-stream",
    };
    let upload_id = uuid::Uuid::new_v4().to_string();

    let mut tx = self.db.begin().await?;
    let upload_row = MultipartUploadsRepository::create_in_tx(
        &mut tx,
        bucket_row.id,
        key,
        &upload_id,
        content_type,
    )
    .await?;

    // Skip the metadata insert entirely when there is nothing to store.
    let has_metadata = !metadata.is_empty();
    if has_metadata {
        MultipartMetadataRepository::insert_batch_in_tx(&mut tx, upload_row.id, &metadata)
            .await?;
    }
    tx.commit().await?;

    let started = CreateMultipartUploadResult {
        bucket: bucket.to_string(),
        key: key.to_string(),
        upload_id,
    };
    Ok(started)
}
/// Stores `body` as part `part_number` of the multipart upload `upload_id`,
/// via the repository's upsert call, and returns the part's ETag (the quoted
/// hex MD5 of `body`).
///
/// # Errors
/// `Post3Error::BucketNotFound` / `Post3Error::UploadNotFound` when the
/// bucket or a matching upload is missing; repository failures are propagated.
async fn upload_part(
    &self,
    bucket: &str,
    key: &str,
    upload_id: &str,
    part_number: i32,
    body: Bytes,
) -> Result<UploadPartResult, Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;
    let upload = self
        .require_upload(upload_id, bucket_row.id, key)
        .await?;

    let size = body.len() as i64;
    let digest = {
        let mut md5 = Md5::new();
        md5.update(&body);
        md5.finalize()
    };
    let etag = format!("\"{}\"", hex::encode(digest));

    UploadPartsRepository::upsert(&self.db, upload.id, part_number, &body, size, &etag).await?;
    Ok(UploadPartResult { etag })
}
async fn complete_multipart_upload(
&self,
bucket: &str,
key: &str,
upload_id: &str,
part_etags: Vec<(i32, String)>,
) -> Result<CompleteMultipartUploadResult, Post3Error> {
let bucket_row = self.require_bucket(bucket).await?;
let upload = self
.require_upload(upload_id, bucket_row.id, key)
.await?;
for window in part_etags.windows(2) {
if window[0].0 >= window[1].0 {
return Err(Post3Error::InvalidPartOrder);
}
}
let part_numbers: Vec<i32> = part_etags.iter().map(|(n, _)| *n).collect();
let parts = UploadPartsRepository::get_ordered_by_numbers(
&self.db,
upload.id,
&part_numbers,
)
.await?;
for (expected_num, expected_etag) in &part_etags {
let part = parts
.iter()
.find(|p| p.part_number == *expected_num)
.ok_or_else(|| Post3Error::InvalidPart {
upload_id: upload_id.to_string(),
part_number: *expected_num,
})?;
let stored = part.etag.trim_matches('"');
let expected = expected_etag.trim_matches('"');
if stored != expected {
return Err(Post3Error::ETagMismatch {
part_number: *expected_num,
expected: expected_etag.clone(),
got: part.etag.clone(),
});
}
}
const MIN_PART_SIZE: i64 = 5 * 1024 * 1024;
for (i, part) in parts.iter().enumerate() {
if i < parts.len() - 1 && part.size < MIN_PART_SIZE {
return Err(Post3Error::EntityTooSmall {
part_number: part.part_number,
size: part.size,
});
}
}
let mut etag_hasher = Md5::new();
let part_count = parts.len();
for part in &parts {
let hex_str = part.etag.trim_matches('"');
if let Ok(raw_md5) = hex::decode(hex_str) {
etag_hasher.update(&raw_md5);
}
}
let compound_etag = format!(
"\"{}-{}\"",
hex::encode(etag_hasher.finalize()),
part_count
);
let total_size: i64 = parts.iter().map(|p| p.size).sum();
let mut assembled = Vec::with_capacity(total_size as usize);
for part in &parts {
assembled.extend_from_slice(&part.data);
}
let user_metadata =
MultipartMetadataRepository::get_all(&self.db, upload.id).await?;
let mut tx = self.db.begin().await?;
let object_row = ObjectsRepository::insert_in_tx(
&mut tx,
bucket_row.id,
key,
total_size,
&compound_etag,
&upload.content_type,
)
.await?;
for (chunk_index, chunk) in assembled.chunks(self.block_size).enumerate() {
BlocksRepository::insert_in_tx(
&mut tx,
object_row.id,
chunk_index as i32,
chunk,
)
.await?;
}
if !user_metadata.is_empty() {
MetadataRepository::insert_batch_in_tx(
&mut tx,
object_row.id,
&user_metadata,
)
.await?;
}
MultipartUploadsRepository::delete_in_tx(&mut tx, upload.id).await?;
tx.commit().await?;
Ok(CompleteMultipartUploadResult {
bucket: bucket.to_string(),
key: key.to_string(),
etag: compound_etag,
size: total_size,
})
}
/// Cancels the multipart upload `upload_id` for `key` in `bucket` by deleting
/// its upload record.
///
/// # Errors
/// `Post3Error::BucketNotFound` / `Post3Error::UploadNotFound` when the
/// bucket or a matching upload is missing; repository failures are propagated.
async fn abort_multipart_upload(
    &self,
    bucket: &str,
    key: &str,
    upload_id: &str,
) -> Result<(), Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;
    let upload = self
        .require_upload(upload_id, bucket_row.id, key)
        .await?;
    // Delete by the id stored on the verified row, not the caller's string.
    let stored_upload_id = &upload.upload_id;
    MultipartUploadsRepository::delete_by_upload_id(&self.db, stored_upload_id).await?;
    Ok(())
}
/// Lists one page of the parts uploaded so far for the multipart upload
/// `upload_id`.
///
/// * `max_parts` – page size, 1000 when absent. Zero or a negative value
///   yields an empty, non-truncated page, matching how `list_objects_v2`
///   treats `max_keys == 0`.
/// * `part_number_marker` – passed to the repository as the position to
///   resume from; use the previous page's `next_part_number_marker`.
///
/// # Errors
/// `Post3Error::BucketNotFound` / `Post3Error::UploadNotFound` when the
/// bucket or a matching upload is missing; repository failures are propagated.
async fn list_parts(
    &self,
    bucket: &str,
    key: &str,
    upload_id: &str,
    max_parts: Option<i32>,
    part_number_marker: Option<i32>,
) -> Result<ListPartsResult, Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;
    let upload = self
        .require_upload(upload_id, bucket_row.id, key)
        .await?;

    // Without this guard a non-positive page size reports "truncated" with
    // no marker to continue from, or sends a negative limit to the query.
    let max_parts = i64::from(max_parts.unwrap_or(1000)).max(0);
    if max_parts == 0 {
        return Ok(ListPartsResult {
            bucket: bucket.to_string(),
            key: key.to_string(),
            upload_id: upload_id.to_string(),
            parts: Vec::new(),
            is_truncated: false,
            next_part_number_marker: None,
        });
    }

    // Ask for one extra row: its presence means another page exists.
    let parts = UploadPartsRepository::list_info(
        &self.db,
        upload.id,
        part_number_marker,
        max_parts + 1,
    )
    .await?;
    let is_truncated = parts.len() as i64 > max_parts;
    let page_len = usize::try_from(max_parts).unwrap_or(usize::MAX);
    let items: Vec<_> = parts.into_iter().take(page_len).collect();
    let next_marker = if is_truncated {
        items.last().map(|p| p.part_number)
    } else {
        None
    };

    Ok(ListPartsResult {
        bucket: bucket.to_string(),
        key: key.to_string(),
        upload_id: upload_id.to_string(),
        parts: items,
        is_truncated,
        next_part_number_marker: next_marker,
    })
}
/// Lists one page of the in-progress multipart uploads in `bucket`.
///
/// * `prefix` – passed to the repository to restrict the keys returned.
/// * `key_marker` / `upload_id_marker` – passed to the repository as the
///   position to resume from; use the previous page's `next_key_marker` and
///   `next_upload_id_marker`.
/// * `max_uploads` – page size, 1000 when absent. Zero or a negative value
///   yields an empty, non-truncated page, matching how `list_objects_v2`
///   treats `max_keys == 0`.
///
/// # Errors
/// `Post3Error::BucketNotFound` if the bucket is missing; repository failures
/// are propagated.
async fn list_multipart_uploads(
    &self,
    bucket: &str,
    prefix: Option<&str>,
    key_marker: Option<&str>,
    upload_id_marker: Option<&str>,
    max_uploads: Option<i32>,
) -> Result<ListMultipartUploadsResult, Post3Error> {
    let bucket_row = self.require_bucket(bucket).await?;

    // Without this guard a non-positive page size reports "truncated" with
    // no markers to continue from, or sends a negative limit to the query.
    let max_uploads = i64::from(max_uploads.unwrap_or(1000)).max(0);
    if max_uploads == 0 {
        return Ok(ListMultipartUploadsResult {
            bucket: bucket.to_string(),
            uploads: Vec::new(),
            is_truncated: false,
            next_key_marker: None,
            next_upload_id_marker: None,
            prefix: prefix.map(|s| s.to_string()),
        });
    }

    // Ask for one extra row: its presence means another page exists.
    let rows = MultipartUploadsRepository::list(
        &self.db,
        bucket_row.id,
        prefix,
        key_marker,
        upload_id_marker,
        max_uploads + 1,
    )
    .await?;
    let is_truncated = rows.len() as i64 > max_uploads;
    let page_len = usize::try_from(max_uploads).unwrap_or(usize::MAX);
    let items: Vec<_> = rows.into_iter().take(page_len).collect();

    // The markers point at the last entry actually returned on this page.
    let (next_key_marker, next_upload_id_marker) = match items.last() {
        Some(last) if is_truncated => {
            (Some(last.key.clone()), Some(last.upload_id.clone()))
        }
        _ => (None, None),
    };

    let uploads = items
        .into_iter()
        .map(|u| MultipartUploadInfo {
            key: u.key,
            upload_id: u.upload_id,
            initiated: u.created_at,
        })
        .collect();

    Ok(ListMultipartUploadsResult {
        bucket: bucket.to_string(),
        uploads,
        is_truncated,
        next_key_marker,
        next_upload_id_marker,
        prefix: prefix.map(|s| s.to_string()),
    })
}
}