rust-integration-services 0.5.26

A modern, fast, and lightweight integration library written in Rust, designed for memory safety and stability.
Documentation
use std::{marker::PhantomData, sync::Arc};

use aws_config::{BehaviorVersion, Region, SdkConfig};
use aws_sdk_s3::{Client, config::{Credentials, SharedCredentialsProvider}, types::{CompletedMultipartUpload, CompletedPart}};
use bytes::{Bytes, BytesMut};
use tokio_util::io::ReaderStream;

use crate::{common::stream::ByteStream, s3::s3_client_config::S3ClientConfig};

/// Type-state marker for an `S3Client` that has been built but has no bucket selected yet.
pub struct NoBucket;
/// Type-state marker for an `S3Client` that has a bucket selected but no object key yet.
pub struct HasBucket;

/// Type-state marker for an `S3Client` that targets one object (bucket + key) for download.
pub struct GetObject;
/// Type-state marker for an `S3Client` that targets one object (bucket + key) for upload.
pub struct PutObject;

/// S3 client handle whose available operations are selected by the `State` marker type.
///
/// All states share the same underlying SDK client through an `Arc`; moving from one
/// state to the next clones that `Arc` and records the chosen bucket and key.
pub struct S3Client<State> {
    /// Shared AWS SDK client used to send every request.
    client: Arc<Client>,
    /// Selected bucket; `None` until a bucket has been chosen.
    bucket: Option<String>,
    /// Selected object key; `None` until a get/put target has been chosen.
    key: Option<String>,
    /// Zero-sized marker that carries the current state in the type.
    _state: PhantomData<State>,
}

impl S3Client<NoBucket> {
    /// Creates a client from `config`. No bucket is selected yet; call `bucket` next.
    pub fn new(config: S3ClientConfig) -> Self {
        Self {
            client: Arc::new(Self::build_client(config)),
            bucket: None,
            key: None,
            _state: PhantomData,
        }
    }

    /// Builds the underlying SDK client with static credentials.
    ///
    /// Missing access/secret keys fall back to empty strings and a missing region
    /// falls back to `"auto"`.
    fn build_client(config: S3ClientConfig) -> Client {
        let access_key = config.access_key.unwrap_or_default();
        let secret_key = config.secret_key.unwrap_or_default();
        let credentials = Credentials::new(access_key, secret_key, None, None, "static");

        // Closure form: only allocate the fallback when no region was configured.
        let region_name = config.region.unwrap_or_else(|| String::from("auto"));

        let sdk_config = SdkConfig::builder()
            .region(Region::new(region_name))
            .credentials_provider(SharedCredentialsProvider::new(credentials))
            .behavior_version(BehaviorVersion::latest())
            .endpoint_url(config.endpoint.as_str())
            .build();

        Client::new(&sdk_config)
    }

    /// Selects `bucket` and returns a handle that can address objects inside it.
    ///
    /// The returned handle shares this client's underlying SDK client.
    pub fn bucket(&self, bucket: impl Into<String>) -> S3Client<HasBucket> {
        S3Client {
            client: Arc::clone(&self.client),
            bucket: Some(bucket.into()),
            key: None,
            _state: PhantomData,
        }
    }
}

impl S3Client<HasBucket> {
    pub fn get_object(&self, key: impl Into<String>) -> S3Client<GetObject> {
        S3Client {
            client: self.client.clone(),
            bucket: self.bucket.clone(),
            key: Some(key.into()),
            _state: PhantomData
        }
    }

    pub fn put_object(&self, key: impl Into<String>) -> S3Client<PutObject> {
        S3Client {
            client: self.client.clone(),
            bucket: self.bucket.clone(),
            key: Some(key.into()),
            _state: PhantomData
        }
    }

    pub async fn delete_object(&self, key: impl AsRef<str>) -> anyhow::Result<()> {
        let _result = self.client
        .delete_object()
        .bucket(self.bucket.as_ref().unwrap())
        .key(key.as_ref())
        .send()
        .await?;

        Ok(())
    }
}

impl S3Client<GetObject> {
    /// Downloads the targeted object and returns its whole body in memory.
    ///
    /// # Errors
    /// Returns an error if the request fails or the body cannot be read completely.
    pub async fn as_bytes(&self) -> anyhow::Result<Bytes> {
        // Bucket and key are always populated by the time a `GetObject` handle exists.
        let request = self
            .client
            .get_object()
            .bucket(self.bucket.as_ref().unwrap())
            .key(self.key.as_ref().unwrap());

        let output = request.send().await?;
        let aggregated = output.body.collect().await?;
        Ok(aggregated.into_bytes())
    }

    /// Downloads the targeted object and exposes its body as a `ByteStream`.
    ///
    /// # Errors
    /// Returns an error if the request fails.
    pub async fn as_stream(&self) -> anyhow::Result<ByteStream> {
        // Bucket and key are always populated by the time a `GetObject` handle exists.
        let request = self
            .client
            .get_object()
            .bucket(self.bucket.as_ref().unwrap())
            .key(self.key.as_ref().unwrap());

        let output = request.send().await?;
        let reader = output.body.into_async_read();
        Ok(ByteStream::new(ReaderStream::new(reader)))
    }
}

impl S3Client<PutObject> {
    /// Buffer threshold (5 MiB) at which a part is flushed during a multipart upload.
    const MIN_PART_SIZE: usize = 5 * 1024 * 1024;

    /// Returns the `(bucket, key)` pair this handle targets.
    fn target(&self) -> (&str, &str) {
        let bucket = self
            .bucket
            .as_deref()
            .expect("bucket is always set in the PutObject state");
        let key = self
            .key
            .as_deref()
            .expect("key is always set in the PutObject state");
        (bucket, key)
    }

    /// Uploads `bytes` as the targeted object with a single put request.
    ///
    /// # Errors
    /// Returns the SDK error if the request fails.
    pub async fn from_bytes(&self, bytes: impl Into<Bytes>) -> anyhow::Result<()> {
        let (bucket, key) = self.target();
        let bytes: Bytes = bytes.into();

        self.client
            .put_object()
            .bucket(bucket)
            .key(key)
            .body(bytes.into())
            .send()
            .await?;

        Ok(())
    }

    /// Uploads `stream` as the targeted object using a multipart upload.
    ///
    /// If any step after the upload has been created fails, the multipart upload
    /// is aborted (best effort) and the original error is returned.
    ///
    /// # Errors
    /// Returns an error if the upload cannot be created, no upload ID is returned,
    /// the stream yields an error, or any part/complete request fails.
    pub async fn from_stream(&self, stream: ByteStream) -> anyhow::Result<()> {
        let (bucket, key) = self.target();

        let create_res = self
            .client
            .create_multipart_upload()
            .bucket(bucket)
            .key(key)
            .send()
            .await?;

        let upload_id = create_res
            .upload_id()
            .ok_or_else(|| anyhow::anyhow!("No upload ID"))?;

        if let Err(err) = self.multipart_upload(upload_id, stream).await {
            // Best-effort cleanup: the upload error is what the caller needs to see,
            // so a failure of the abort request itself is deliberately ignored.
            let _ = self
                .client
                .abort_multipart_upload()
                .bucket(bucket)
                .key(key)
                .upload_id(upload_id)
                .send()
                .await;

            return Err(err);
        }

        Ok(())
    }

    /// Drains `stream` into parts of at least `MIN_PART_SIZE` bytes, uploads them in
    /// order, and completes the multipart upload identified by `upload_id`.
    async fn multipart_upload(&self, upload_id: &str, mut stream: ByteStream) -> anyhow::Result<()> {
        let (bucket, key) = self.target();
        let mut completed_parts = Vec::new();
        let mut part_number = 1;
        let mut buffer = BytesMut::with_capacity(Self::MIN_PART_SIZE);

        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            buffer.extend_from_slice(&chunk);

            if buffer.len() >= Self::MIN_PART_SIZE {
                // `split` hands off everything buffered so far and leaves `buffer` empty.
                let part = self
                    .upload_part(upload_id, part_number, buffer.split().freeze())
                    .await?;
                completed_parts.push(part);
                part_number += 1;
            }
        }

        // Flush the remainder. An upload cannot be completed with zero parts, so an
        // empty stream still sends one (empty) part instead of failing on completion.
        if !buffer.is_empty() || completed_parts.is_empty() {
            let part = self
                .upload_part(upload_id, part_number, buffer.freeze())
                .await?;
            completed_parts.push(part);
        }

        let completed_upload = CompletedMultipartUpload::builder()
            .set_parts(Some(completed_parts))
            .build();

        self.client
            .complete_multipart_upload()
            .bucket(bucket)
            .key(key)
            .upload_id(upload_id)
            .multipart_upload(completed_upload)
            .send()
            .await?;

        Ok(())
    }

    /// Uploads one part and returns the descriptor needed to complete the upload.
    ///
    /// If the response carries no ETag, an empty string is recorded for the part.
    async fn upload_part(&self, upload_id: &str, part_number: i32, bytes: bytes::Bytes) -> anyhow::Result<CompletedPart> {
        let (bucket, key) = self.target();

        let upload_part_res = self
            .client
            .upload_part()
            .bucket(bucket)
            .key(key)
            .upload_id(upload_id)
            .part_number(part_number)
            .body(bytes.into())
            .send()
            .await?;

        let completed = CompletedPart::builder()
            .e_tag(upload_part_res.e_tag().unwrap_or_default())
            .part_number(part_number)
            .build();

        Ok(completed)
    }
}