use std::time::Duration;
use axum::body::Body;
use axum::http::{StatusCode, header};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use rusty_gasket::BoxError;
/// Object store backed by a single S3 bucket.
///
/// Every operation issues requests through `client` against `bucket`. Cloning
/// clones both the client handle and the bucket name. `Debug` is implemented by
/// hand and prints only the bucket.
#[derive(Clone)]
pub struct S3ObjectStore {
    /// SDK client used for every request issued by this store.
    client: aws_sdk_s3::Client,
    /// Bucket that every key is resolved against.
    bucket: String,
}
impl std::fmt::Debug for S3ObjectStore {
    /// Formats the store showing only its bucket; the SDK client is elided and the
    /// output is marked non-exhaustive (`..`).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("S3ObjectStore");
        out.field("bucket", &self.bucket);
        out.finish_non_exhaustive()
    }
}
/// Metadata describing an existing object, as produced by `S3ObjectStore::head`.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ObjectMeta {
    /// Object size in bytes; `None` when S3 omitted it or reported a negative value.
    pub content_length: Option<u64>,
    /// Content type stored with the object, if any.
    pub content_type: Option<String>,
    /// Entity tag reported by S3, if any.
    pub e_tag: Option<String>,
}
impl S3ObjectStore {
pub fn new(client: aws_sdk_s3::Client, bucket: impl Into<String>) -> Self {
Self {
client,
bucket: bucket.into(),
}
}
pub async fn from_env(bucket: impl Into<String>) -> Self {
let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
Self::new(aws_sdk_s3::Client::new(&config), bucket)
}
#[must_use]
pub fn bucket(&self) -> &str {
&self.bucket
}
pub async fn get(&self, key: &str) -> Result<Bytes, BoxError> {
let output = self
.client
.get_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|e| format!("S3 get_object {}/{key} failed: {e}", self.bucket))?;
let data = output
.body
.collect()
.await
.map_err(|e| format!("S3 read body {}/{key} failed: {e}", self.bucket))?;
Ok(data.into_bytes())
}
pub async fn put(
&self,
key: &str,
body: Bytes,
content_type: Option<&str>,
) -> Result<(), BoxError> {
let mut request = self
.client
.put_object()
.bucket(&self.bucket)
.key(key)
.body(body.into());
if let Some(ct) = content_type {
request = request.content_type(ct);
}
request
.send()
.await
.map_err(|e| format!("S3 put_object {}/{key} failed: {e}", self.bucket))?;
Ok(())
}
pub async fn head(&self, key: &str) -> Result<Option<ObjectMeta>, BoxError> {
let result = self
.client
.head_object()
.bucket(&self.bucket)
.key(key)
.send()
.await;
match result {
Ok(output) => Ok(Some(ObjectMeta {
content_length: output.content_length().and_then(|n| u64::try_from(n).ok()),
content_type: output.content_type().map(str::to_owned),
e_tag: output.e_tag().map(str::to_owned),
})),
Err(error) => {
if error
.as_service_error()
.is_some_and(aws_sdk_s3::operation::head_object::HeadObjectError::is_not_found)
{
Ok(None)
} else {
Err(format!("S3 head_object {}/{key} failed: {error}", self.bucket).into())
}
}
}
}
pub async fn list(&self, prefix: &str) -> Result<Vec<String>, BoxError> {
let mut keys = Vec::new();
let mut continuation: Option<String> = None;
loop {
let mut request = self
.client
.list_objects_v2()
.bucket(&self.bucket)
.prefix(prefix);
if let Some(token) = &continuation {
request = request.continuation_token(token);
}
let output = request
.send()
.await
.map_err(|e| format!("S3 list_objects_v2 {}/{prefix} failed: {e}", self.bucket))?;
for object in output.contents() {
if let Some(key) = object.key() {
keys.push(key.to_owned());
}
}
if output.is_truncated().unwrap_or(false) {
continuation = output.next_continuation_token().map(str::to_owned);
if continuation.is_none() {
break;
}
} else {
break;
}
}
Ok(keys)
}
pub async fn presigned_get(&self, key: &str, expires_in: Duration) -> Result<String, BoxError> {
let presigning = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in)
.map_err(|e| format!("S3 presign config invalid: {e}"))?;
let presigned = self
.client
.get_object()
.bucket(&self.bucket)
.key(key)
.presigned(presigning)
.await
.map_err(|e| format!("S3 presign {}/{key} failed: {e}", self.bucket))?;
Ok(presigned.uri().to_owned())
}
pub async fn download_response(&self, key: &str) -> Response {
let output = match self
.client
.get_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
{
Ok(output) => output,
Err(error) => {
if error
.as_service_error()
.is_some_and(aws_sdk_s3::operation::get_object::GetObjectError::is_no_such_key)
{
return (StatusCode::NOT_FOUND, "not found").into_response();
}
tracing::warn!(bucket = %self.bucket, key, %error, "S3 download failed");
return (StatusCode::BAD_GATEWAY, "upstream storage error").into_response();
}
};
let content_type = output
.content_type()
.unwrap_or("application/octet-stream")
.to_owned();
let content_length = output.content_length();
let stream = tokio_util::io::ReaderStream::new(output.body.into_async_read());
let mut builder = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, content_type);
if let Some(len) = content_length {
builder = builder.header(header::CONTENT_LENGTH, len);
}
match builder.body(Body::from_stream(stream)) {
Ok(response) => response,
Err(error) => {
tracing::error!(%error, "failed to build S3 download response");
(StatusCode::INTERNAL_SERVER_ERROR, "response build error").into_response()
}
}
}
}