use async_trait::async_trait;
use futures::{stream::FuturesUnordered, Future, StreamExt};
use std::{io, pin::Pin, sync::Arc, task::Poll};
use tokio::io::AsyncWrite;
use crate::Result;
/// Heap-allocated, type-erased `Send` future resolving to `Result<T, io::Error>`.
///
/// Used for both the per-part upload tasks and the final completion request,
/// so that futures of different concrete types can be stored side by side.
type BoxedTryFuture<T> =
    Pin<Box<dyn Future<Output = Result<T, io::Error>> + Send>>;
/// Backend-specific operations that [`CloudMultiPartUpload`] drives to perform
/// a multipart upload.
///
/// The `'static` supertrait lets `Arc<Self>` be moved into the boxed
/// (implicitly `'static`) futures that `CloudMultiPartUpload` stores.
#[async_trait]
pub(crate) trait CloudMultiPartUploadImpl: 'static {
    /// Called with the bytes of one part (`buf`) and its zero-based index
    /// (`part_idx`); presumably uploads them and returns the backend's
    /// identifier for that part.
    ///
    /// When driven by `CloudMultiPartUpload`, several calls may be in flight at
    /// once and may finish in any order. Every part except the last one holds
    /// at least `min_part_size` bytes.
    async fn put_multipart_part(
        &self,
        buf: Vec<u8>,
        part_idx: usize,
    ) -> Result<UploadPart, io::Error>;

    /// Finishes the upload.
    ///
    /// `CloudMultiPartUpload` calls this once, after every part upload has
    /// finished, with the parts ordered by their part index.
    async fn complete(&self, completed_parts: Vec<UploadPart>) -> Result<(), io::Error>;
}
/// Result of one successfully uploaded part.
///
/// Values are collected by part index and later handed, in order, to
/// [`CloudMultiPartUploadImpl::complete`].
#[derive(Clone, Debug)]
pub(crate) struct UploadPart {
    /// Identifier returned by the backend for this part.
    /// NOTE(review): presumably an ETag / block id, depending on the backend.
    pub content_id: String,
}
/// An [`AsyncWrite`] adapter that buffers written bytes and uploads them as the
/// parts of a multipart upload through a [`CloudMultiPartUploadImpl`].
///
/// Several part uploads may run concurrently. `poll_shutdown` uploads whatever
/// is still buffered and then completes the upload.
pub(crate) struct CloudMultiPartUpload<T: CloudMultiPartUploadImpl> {
    /// Backend implementation, shared with every in-flight upload future.
    inner: Arc<T>,
    /// Finished parts indexed by part number; `None` marks a slot whose part
    /// has not (yet) finished.
    completed_parts: Vec<Option<UploadPart>>,
    /// In-flight part uploads, each resolving to `(part index, part)`.
    tasks: FuturesUnordered<BoxedTryFuture<(usize, UploadPart)>>,
    /// Concurrency limit compared against `tasks.len()` before a part starts.
    max_concurrency: usize,
    /// Bytes written but not yet handed to a part upload.
    current_buffer: Vec<u8>,
    /// Buffered size at which `poll_write` starts a new (non-final) part.
    min_part_size: usize,
    /// Index assigned to the next part started by `poll_write`.
    current_part_idx: usize,
    /// Future that completes the upload; created lazily by `poll_shutdown`.
    completion_task: Option<BoxedTryFuture<()>>,
}
impl<T> CloudMultiPartUpload<T>
where
    T: CloudMultiPartUploadImpl,
{
    /// Buffered size (5 MiB) at which `poll_write` starts a non-final part.
    /// NOTE(review): presumably chosen to satisfy the backends' minimum
    /// multipart part size — confirm per backend.
    const MIN_PART_SIZE: usize = 5 * 1024 * 1024;

    /// Creates a writer that uploads through `inner`, keeping at most
    /// `max_concurrency` part uploads in flight.
    ///
    /// A `max_concurrency` of `0` is treated as `1`. A zero limit would never
    /// admit an upload, and the writer would return `Poll::Pending` forever
    /// without registering a waker.
    pub fn new(inner: T, max_concurrency: usize) -> Self {
        Self {
            inner: Arc::new(inner),
            completed_parts: Vec::new(),
            tasks: FuturesUnordered::new(),
            max_concurrency: max_concurrency.max(1),
            current_buffer: Vec::new(),
            min_part_size: Self::MIN_PART_SIZE,
            current_part_idx: 0,
            completion_task: None,
        }
    }

    /// Polls the in-flight part uploads and records every finished part at its
    /// index in `completed_parts`.
    ///
    /// When at least one upload is still pending, the set registers `cx`'s
    /// waker, so the caller is woken when more progress is possible.
    ///
    /// # Errors
    ///
    /// Returns the first upload error encountered. Uploads not yet drained stay
    /// in the set and are polled again on the next call.
    pub fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Result<(), io::Error> {
        if self.tasks.is_empty() {
            return Ok(());
        }
        while let Poll::Ready(Some(res)) = self.tasks.poll_next_unpin(cx) {
            let (part_idx, part) = res?;
            // Parts may finish out of order, so grow the slot vector on demand.
            if part_idx >= self.completed_parts.len() {
                self.completed_parts.resize(part_idx + 1, None);
            }
            self.completed_parts[part_idx] = Some(part);
        }
        Ok(())
    }
}
impl<T> CloudMultiPartUpload<T>
where
    T: CloudMultiPartUploadImpl + Send + Sync,
{
    /// Uploads whatever is left in `current_buffer` as the last part and drives
    /// every in-flight part upload to completion.
    ///
    /// The last part may be smaller than `min_part_size`. Returns
    /// `Poll::Ready(Ok(()))` once the buffer is empty and no part upload is in
    /// flight, and propagates the first upload error reported by `poll_tasks`.
    fn final_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        // Harvest finished uploads first so their slots count as free.
        self.as_mut().poll_tasks(cx)?;

        // An idle writer may always start a part. Otherwise a zero
        // `max_concurrency` would never submit the buffer. `poll_tasks`
        // registers no waker on an empty set, so the caller would hang.
        let has_capacity =
            self.tasks.is_empty() || self.tasks.len() < self.max_concurrency;
        if !self.current_buffer.is_empty() && has_capacity {
            let out_buffer: Vec<u8> = std::mem::take(&mut self.current_buffer);
            let inner = Arc::clone(&self.inner);
            let part_idx = self.current_part_idx;
            self.tasks.push(Box::pin(async move {
                let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?;
                Ok((part_idx, upload_part))
            }));
            // Advance the index as `poll_write` does, so it is never reused.
            self.current_part_idx += 1;
        }

        // Poll again so a freshly pushed upload starts and registers the waker.
        self.as_mut().poll_tasks(cx)?;

        if self.tasks.is_empty() && self.current_buffer.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
}
impl<T> AsyncWrite for CloudMultiPartUpload<T>
where
T: CloudMultiPartUploadImpl + Send + Sync,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.as_mut().poll_tasks(cx)?;
let enough_to_send =
(buf.len() + self.current_buffer.len()) >= self.min_part_size;
if enough_to_send && self.tasks.len() < self.max_concurrency {
self.current_buffer.extend_from_slice(buf);
let out_buffer = std::mem::take(&mut self.current_buffer);
let inner = Arc::clone(&self.inner);
let part_idx = self.current_part_idx;
self.tasks.push(Box::pin(async move {
let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?;
Ok((part_idx, upload_part))
}));
self.current_part_idx += 1;
self.as_mut().poll_tasks(cx)?;
Poll::Ready(Ok(buf.len()))
} else if !enough_to_send {
self.current_buffer.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
} else {
Poll::Pending
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
self.as_mut().poll_tasks(cx)?;
if self.tasks.is_empty() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
match self.as_mut().final_flush(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(res) => res?,
};
let parts = std::mem::take(&mut self.completed_parts);
let parts = parts
.into_iter()
.enumerate()
.map(|(idx, part)| {
part.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!("Missing information for upload part {idx}"),
)
})
})
.collect::<Result<_, _>>()?;
let inner = Arc::clone(&self.inner);
let completion_task = self.completion_task.get_or_insert_with(|| {
Box::pin(async move {
inner.complete(parts).await?;
Ok(())
})
});
Pin::new(completion_task).poll(cx)
}
}