//! rhombus 0.2.21
//!
//! Next generation extendable CTF framework with batteries included.
//! (Module documentation header; see the project documentation for details.)
use std::{io, str::FromStr, sync::Arc};

use async_hash::{Digest, Sha256};
use axum::{body::Bytes, routing::post, Router};
use futures::{Stream, TryStreamExt};

use s3::{creds::Credentials, Bucket};
use tokio::io::AsyncReadExt;
use tokio_util::{bytes::BytesMut, io::StreamReader};

use crate::{
    internal::{
        local_upload_provider::slice_to_hex_string, settings::S3UploadProviderSettings,
        upload_provider::route_upload_file,
    },
    upload_provider::UploadProvider,
    Result,
};

/// Upload provider that stores uploaded files in an S3-compatible bucket.
#[derive(Clone)]
pub struct S3UploadProvider {
    // Handle to the configured target bucket.
    pub bucket: Box<Bucket>,
    // When `Some`, download URLs are presigned GETs created with this
    // expiry value (passed straight to `Bucket::presign_get`; presumably
    // seconds — confirm against the rust-s3 API); when `None`, a plain
    // `{bucket_url}/{key}` URL is returned instead.
    pub presigned_get_expiry: Option<u32>,
    // Key prefix prepended verbatim to every uploaded object's path
    // (no separator is inserted — include a trailing `/` in config if wanted).
    pub prefix: String,
}

impl S3UploadProvider {
    /// Build a provider from the configured S3 settings.
    ///
    /// Resolves credentials, selects either a custom-endpoint region (when
    /// `endpoint` is set, e.g. for MinIO-style deployments) or a named AWS
    /// region (defaulting to `us-east-1`), and checks that the configured
    /// bucket exists. A missing bucket is logged as an error but is not
    /// fatal, so startup proceeds and uploads simply fail later.
    ///
    /// # Errors
    ///
    /// Returns an error if credential resolution or bucket construction
    /// fails, or if the bucket existence check cannot be performed.
    pub async fn new(s3: &S3UploadProviderSettings) -> Result<S3UploadProvider> {
        let credentials = Credentials::new(
            s3.access_key.as_deref(),
            s3.secret_key.as_deref(),
            s3.security_token.as_deref(),
            s3.session_token.as_deref(),
            s3.profile.as_deref(),
        )?;

        // A custom endpoint forces Region::Custom; otherwise parse the
        // configured region name, falling back to us-east-1.
        let region = match &s3.endpoint {
            Some(endpoint) => s3::Region::Custom {
                region: s3
                    .bucket_region
                    .clone()
                    .unwrap_or_else(|| "us-east-1".to_owned()),
                endpoint: endpoint.clone(),
            },
            None => s3::Region::from_str(s3.bucket_region.as_deref().unwrap_or("us-east-1"))
                .expect("Region::from_str maps unknown names to Region::Custom and cannot fail"),
        };

        let bucket = Bucket::new(&s3.bucket_name, region, credentials)?;

        // Path-style addressing ({endpoint}/{bucket}/...) instead of
        // virtual-hosted style; required by some S3-compatible servers.
        let bucket = if s3.path_style.unwrap_or(false) {
            bucket.with_path_style()
        } else {
            bucket
        };

        if !bucket.exists().await? {
            // Deliberately non-fatal: log loudly and keep starting up.
            tracing::error!("configured bucket does not exist");
        }

        Ok(S3UploadProvider {
            bucket,
            presigned_get_expiry: s3.presigned_get_expiry,
            prefix: s3.prefix.clone().unwrap_or_default(),
        })
    }
}

impl UploadProvider for S3UploadProvider {
    /// Mount the file-upload route, sharing this provider as router state.
    fn routes(&self) -> Result<Router> {
        let provider_state = Arc::new(self.clone());
        let router = Router::new()
            .route("/upload/:path", post(route_upload_file::<Self>))
            .with_state(provider_state);
        Ok(router)
    }

    /// Buffer the incoming byte stream, store it in the bucket under a
    /// content-hash key, and return a URL for retrieving it.
    ///
    /// The object key is `{prefix}{sha256-hex}/{filename}`, so identical
    /// content uploaded under the same filename maps to the same key.
    /// NOTE(review): the whole body is buffered in memory before the
    /// upload — fine for CTF-sized files, but worth confirming there is
    /// an upstream body-size limit.
    ///
    /// # Errors
    ///
    /// Returns an error if reading the stream fails, if the S3 PUT fails,
    /// or if presigning the GET URL fails.
    async fn upload<S, E>(&self, filename: &str, stream: S) -> Result<String>
    where
        S: Stream<Item = std::result::Result<Bytes, E>> + Send,
        E: Into<axum::BoxError>,
    {
        // Adapt the generic stream error into io::Error so StreamReader
        // can surface it through AsyncRead.
        let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
        let body_reader = StreamReader::new(body_with_io_error);

        futures::pin_mut!(body_reader);

        // Read the entire body into memory. Unlike the previous
        // `while let Ok(..)` form, a read error here is PROPAGATED instead
        // of silently ending the loop — which would have uploaded (and
        // hashed) truncated data as though it were complete.
        let mut buffer = BytesMut::new();
        loop {
            let bytes_read = body_reader.read_buf(&mut buffer).await?;
            if bytes_read == 0 {
                break; // EOF
            }
        }

        // Content-address the object by the SHA-256 of its bytes.
        let mut hasher = Sha256::new();
        hasher.update(&buffer);
        let hash = slice_to_hex_string(hasher.finalize().as_slice());

        let s3_path = format!("{}{}/{}", self.prefix, hash, filename);

        _ = self.bucket.put_object(&s3_path, &buffer).await?;

        // Presigned GET when an expiry is configured; otherwise a plain
        // public bucket URL.
        let url = if let Some(presigned_get_expiry) = self.presigned_get_expiry {
            self.bucket
                .presign_get(&s3_path, presigned_get_expiry, None)
                .await?
        } else {
            format!("{}/{}", self.bucket.url(), &s3_path)
        };

        tracing::info!(url, "uploaded to s3");

        Ok(url)
    }
}