use std::sync::Arc;
use async_channel::{Receiver, Sender};
use aws_sdk_s3::model::CompletedPart;
use futures::Stream;
use futures_util::StreamExt;
use s3::s3backend::S3Backend;
use tokio::try_join;
mod s3;
/// Default part size in bytes (100 MiB) used by `upload_file` when no `chunk_size` is given.
/// Downloads whose `Content-Length` exceeds the part size are uploaded as S3 multipart uploads.
pub const UPLOAD_CHUNK_SIZE: u64 = 100 * 1024 * 1024;
/// Streams the resource at `url` into the S3 object `bucket`/`key` without buffering it in full.
///
/// HTTP body chunks are forwarded through bounded channels to the uploader(s) as they arrive.
/// When the response's `Content-Length` exceeds `chunk_size` (default [`UPLOAD_CHUNK_SIZE`]),
/// the body is cut into consecutive parts of `chunk_size` bytes (the last one possibly shorter),
/// each uploaded by its own task via `spawn_multi_upload`, and the multipart upload is then
/// completed. Otherwise the body is sent with a single `upload_object` call.
///
/// # Errors
///
/// Returns an error if `chunk_size` is `Some(0)`, the HTTP request fails, the response has no
/// `Content-Length`, reading the body fails, or any S3 call / upload task fails. NOTE: a failed
/// multipart upload is not aborted here.
///
/// # Panics
///
/// Panics if `S3Backend::new` fails.
pub async fn upload_file(
    url: String,
    bucket: String,
    key: String,
    chunk_size: Option<u64>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
    /// Capacity of every channel between the download loop and an uploader.
    const PART_CHANNEL_CAPACITY: usize = 30;

    let chunk_size = chunk_size.unwrap_or(UPLOAD_CHUNK_SIZE);
    if chunk_size == 0 {
        return Err("chunk_size must be greater than zero".into());
    }
    let s3backend = Arc::new(
        S3Backend::new()
            .await
            .expect("Error in initializing s3backend"),
    );
    let resp = reqwest::get(url).await?;
    // Part sizes are announced to the uploaders up front, so the total size must be known.
    let cont_length = resp.content_length().ok_or("ContentLength is needed!")?;
    let mut data_stream = resp.bytes_stream();

    if cont_length > chunk_size {
        let upload_id = s3backend
            .init_multipart_upload(bucket.clone(), key.clone())
            .await?;
        let mut part_number: i32 = 1;
        // Bytes already forwarded to the part currently being filled (always <= chunk_size).
        let mut filled: u64 = 0;
        let (mut chan_send, chan_recv) = async_channel::bounded(PART_CHANNEL_CAPACITY);
        let mut queue = vec![spawn_multi_upload(
            s3backend.clone(),
            bucket.clone(),
            key.clone(),
            upload_id.to_string(),
            chan_recv,
            part_number,
            multipart_part_len(part_number, cont_length, chunk_size),
        )];

        while let Some(chunk) = data_stream.next().await {
            let mut data = chunk?;
            // A network chunk may straddle a part boundary (or, with a small chunk_size, span
            // several parts): keep cutting until the remainder fits into the current part.
            while filled + data.len() as u64 > chunk_size {
                // `split_to` returns the LEADING bytes and leaves the tail in `data`, so the
                // current part receives the bytes that come first in the download.
                let head = data.split_to((chunk_size - filled) as usize);
                if !head.is_empty() {
                    chan_send.send(Ok(head)).await?;
                }
                part_number += 1;
                let (next_send, next_recv) = async_channel::bounded(PART_CHANNEL_CAPACITY);
                // Replacing the sender drops the previous one, closing the finished part's channel.
                chan_send = next_send;
                queue.push(spawn_multi_upload(
                    s3backend.clone(),
                    bucket.clone(),
                    key.clone(),
                    upload_id.to_string(),
                    next_recv,
                    part_number,
                    multipart_part_len(part_number, cont_length, chunk_size),
                ));
                filled = 0;
            }
            filled += data.len() as u64;
            if !data.is_empty() {
                chan_send.send(Ok(data)).await?;
            }
        }
        // Close the final part's channel the same way earlier parts' channels were closed;
        // keeping the sender alive while awaiting the tasks below could leave that task
        // waiting for more data.
        drop(chan_send);

        let mut completed_parts = Vec::with_capacity(queue.len());
        for handle in queue {
            let (part, e_tag) = handle.await??;
            completed_parts.push(
                CompletedPart::builder()
                    .e_tag(e_tag)
                    .part_number(part)
                    .build(),
            );
        }
        s3backend
            .finish_multipart_upload(bucket, key, completed_parts, upload_id)
            .await?;
    } else {
        let (chan_send, chan_recv) = async_channel::bounded(PART_CHANNEL_CAPACITY);
        let single_uploader = s3backend.upload_object(chan_recv, bucket, key, cont_length as i64);
        let pro_chunks = process_chunks(data_stream, chan_send);
        if let Err(err) = try_join!(single_uploader, pro_chunks) {
            log::error!("{}", err);
            // Propagate instead of reporting success, matching the multipart path.
            return Err(err);
        }
    }
    Ok(())
}

/// Length in bytes of 1-based part `part_number` when `total_len` bytes are cut into
/// consecutive parts of `part_size` bytes: `part_size` for full parts, the remainder for the
/// final partial part, and 0 for part numbers past the end of the data.
fn multipart_part_len(part_number: i32, total_len: u64, part_size: u64) -> i64 {
    let parts_before = (part_number.max(1) - 1) as u64;
    let offset = parts_before.saturating_mul(part_size);
    total_len.saturating_sub(offset).min(part_size) as i64
}
/// Relays every item of `data_stream` (downloaded bytes or a download error) into `chan_send`,
/// preserving order. `upload_file` uses it to feed the single-object uploader.
///
/// Fails if a send on the channel fails (e.g. the receiving side has gone away). The sender is
/// consumed and dropped when this function returns.
async fn process_chunks(
    mut data_stream: impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> + std::marker::Unpin,
    chan_send: Sender<Result<bytes::Bytes, reqwest::Error>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
    loop {
        match data_stream.next().await {
            Some(item) => chan_send.send(item).await?,
            None => return Ok(()),
        }
    }
}
/// Spawns a Tokio task that uploads one part (`part_number`) of the multipart upload `upload_id`.
///
/// The part's data is read from `recv_chan` by `S3Backend::upload_multi_object`, which is also
/// given the part's expected `content_len`. The returned handle resolves to that call's result;
/// `upload_file` interprets the success value as `(part_number, e_tag)`.
fn spawn_multi_upload(
    backend: Arc<S3Backend>,
    bucket: String,
    key: String,
    upload_id: String,
    recv_chan: Receiver<Result<bytes::Bytes, reqwest::Error>>,
    part_number: i32,
    content_len: i64,
) -> tokio::task::JoinHandle<
    Result<(i32, String), Box<dyn std::error::Error + Sync + std::marker::Send>>,
> {
    let part_upload = async move {
        backend
            .upload_multi_object(recv_chan, bucket, key, upload_id, content_len, part_number)
            .await
    };
    tokio::spawn(part_upload)
}