use crate::{
StreamData,
endpoint::S3Endpoint,
error::Error,
writer::{StreamWriter, WriteJob},
};
use async_std::channel::Receiver;
use async_trait::async_trait;
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
use std::{
sync::mpsc::{Sender, channel},
thread,
time::Duration,
};
use threadpool::ThreadPool;
impl S3Endpoint {
    /// Initiates a multipart upload for `path` and returns the S3 upload ID.
    ///
    /// # Errors
    /// Fails if the CreateMultipartUpload request errors, or if S3 responds
    /// without an upload ID.
    pub async fn start_multi_part_s3_upload(&self, path: &str) -> Result<String, Error> {
        let client = self.connection();
        let object = client
            .create_multipart_upload()
            .bucket(self.bucket().to_string())
            .key(path.to_string())
            .send()
            .await?;
        object
            .upload_id
            .ok_or_else(|| Error::Other("Cannot retrieve upload ID from object".to_string()))
    }

    /// Uploads one part of an ongoing multipart upload and returns the
    /// `CompletedPart` (ETag + part number) needed for completion.
    ///
    /// `part_number` must be unique within the upload and ≥ 1; S3 requires
    /// every part except the last to be at least 5 MiB.
    ///
    /// # Errors
    /// Fails if the UploadPart request errors.
    pub async fn upload_s3_part(
        &self,
        path: &str,
        upload_id: &str,
        part_number: i32,
        data: Vec<u8>,
    ) -> Result<CompletedPart, Error> {
        let client = self.connection();
        let object = client
            .upload_part()
            .body(ByteStream::from(data))
            .bucket(self.bucket().to_string())
            .key(path.to_string())
            .upload_id(upload_id.to_string())
            .part_number(part_number)
            .send()
            .await?;
        // S3 identifies each uploaded part by its ETag; pair it with the part
        // number so the completion request can reference it.
        Ok(CompletedPart::builder()
            .set_e_tag(object.e_tag)
            .set_part_number(Some(part_number))
            .build())
    }

    /// Uploads one part, then forwards the resulting `CompletedPart` to the
    /// collector channel so the caller can later assemble the completion
    /// request.
    ///
    /// # Errors
    /// Fails if the part upload errors or the receiving end of
    /// `part_sender` has been dropped.
    async fn upload_s3_part_and_send(
        &self,
        cloned_path: &str,
        upload_identifier: &str,
        part_number: i32,
        part_buffer: Vec<u8>,
        part_sender: Sender<CompletedPart>,
    ) -> Result<(), Error> {
        // Borrow the arguments directly; cloning the endpoint and both
        // strings (as the previous revision did) buys nothing here.
        let part = self
            .upload_s3_part(cloned_path, upload_identifier, part_number, part_buffer)
            .await?;
        part_sender.send(part).map_err(|e| e.into())
    }

    /// Finalizes a multipart upload from the given list of completed parts.
    ///
    /// `parts` should be sorted by ascending part number, as required by the
    /// CompleteMultipartUpload API.
    ///
    /// # Errors
    /// Fails if the CompleteMultipartUpload request errors.
    pub async fn complete_s3_upload(
        &self,
        path: &str,
        upload_id: &str,
        parts: Vec<CompletedPart>,
    ) -> Result<(), Error> {
        let multipart_upload = CompletedMultipartUpload::builder()
            .set_parts(Some(parts))
            .build();
        let client = self.connection();
        let _response = client
            .complete_multipart_upload()
            .bucket(self.bucket().to_string())
            .key(path.to_string())
            .upload_id(upload_id.to_string())
            .multipart_upload(multipart_upload)
            .send()
            .await?;
        Ok(())
    }
}
#[async_trait]
impl StreamWriter for S3Endpoint {
    /// Consumes `StreamData` messages from `receiver` and writes them to S3
    /// at `path` as a multipart upload, reporting whole-percent progress
    /// through `job_and_notification` when the total size is known.
    ///
    /// The part size defaults to 10 MiB and can be overridden (in bytes) via
    /// the `S3_WRITER_PART_SIZE` environment variable; invalid or missing
    /// values fall back to the default. Parts are uploaded sequentially on
    /// this task — the mpsc channel only collects the `CompletedPart`
    /// results for the final completion request.
    ///
    /// # Errors
    /// Fails if any S3 request or the progress callback errors.
    async fn write_stream(
        &self,
        path: &str,
        receiver: Receiver<StreamData>,
        job_and_notification: &dyn WriteJob,
    ) -> Result<(), Error> {
        let upload_identifier = self.start_multi_part_s3_upload(path).await?;
        // S3 requires every part except the last to be at least 5 MiB;
        // default to 10 MiB parts unless overridden by the environment.
        let part_size = std::env::var("S3_WRITER_PART_SIZE")
            .ok()
            .and_then(|value| value.parse::<usize>().ok())
            .unwrap_or(10 * 1024 * 1024);
        let mut part_buffer: Vec<u8> = Vec::with_capacity(part_size);
        let mut part_number = 1;
        let mut n_jobs = 0;
        let mut file_size = None;
        let mut received_bytes = 0;
        let mut prev_percent = 0;
        let (part_sender, part_receiver) = channel();
        while let Ok(mut stream_data) = receiver.recv().await {
            match stream_data {
                StreamData::Size(size) => file_size = Some(size),
                // NOTE(review): on Stop (and likewise if the sender is
                // dropped without an Eof) the multipart upload is left open
                // on S3 — consider calling AbortMultipartUpload here to
                // avoid orphaned parts accruing storage costs.
                StreamData::Stop => break,
                StreamData::Eof => {
                    // Flush whatever remains in the buffer. Skip an empty
                    // trailing part when earlier parts were already uploaded
                    // (the stream length was an exact multiple of
                    // part_size); a completed upload still needs at least
                    // one part, so a fully empty stream uploads one empty
                    // part.
                    if !part_buffer.is_empty() || n_jobs == 0 {
                        n_jobs += 1;
                        let buffer = std::mem::take(&mut part_buffer);
                        self.upload_s3_part_and_send(
                            path,
                            &upload_identifier,
                            part_number,
                            buffer,
                            part_sender.clone(),
                        )
                        .await?;
                    }
                    // Every send has already happened by this point, so
                    // exactly n_jobs parts are waiting in the channel.
                    let mut complete_parts = part_receiver
                        .iter()
                        .take(n_jobs)
                        .collect::<Vec<CompletedPart>>();
                    // CompleteMultipartUpload requires ascending part order.
                    complete_parts
                        .sort_by(|part1, part2| part1.part_number.cmp(&part2.part_number));
                    self.complete_s3_upload(path, &upload_identifier, complete_parts)
                        .await?;
                    break;
                }
                StreamData::Data(ref mut data) => {
                    received_bytes += data.len();
                    if let Some(file_size) = file_size {
                        let percent = (received_bytes as f32 / file_size as f32 * 100.0) as u8;
                        // Notify only on whole-percent increases to limit
                        // callback traffic.
                        if percent > prev_percent {
                            prev_percent = percent;
                            job_and_notification.progress(percent)?;
                        }
                    }
                    part_buffer.append(data);
                    if part_buffer.len() > part_size {
                        // Hand the full buffer off and start a fresh one;
                        // this avoids copying up to part_size bytes per part
                        // that a clone-then-clear cycle would cost.
                        let buffer = std::mem::replace(
                            &mut part_buffer,
                            Vec::with_capacity(part_size),
                        );
                        self.upload_s3_part_and_send(
                            path,
                            &upload_identifier,
                            part_number,
                            buffer,
                            part_sender.clone(),
                        )
                        .await?;
                        n_jobs += 1;
                        part_number += 1;
                    }
                }
            }
        }
        Ok(())
    }
}