use futures::{future::BoxFuture, stream::FuturesUnordered, Future, StreamExt};
use std::{io, pin::Pin, sync::Arc, task::Poll};
use tokio::io::AsyncWrite;
use crate::Result;
/// Heap-allocated, `Send`, `'static` future that resolves to `T` or an I/O error.
type BoxedTryFuture<T> = Pin<Box<dyn Future<Output = Result<T, io::Error>> + Send + 'static>>;
/// Backend hooks that `CloudMultiPartUpload` uses to upload individual parts
/// and to finish the upload.
pub(crate) trait CloudMultiPartUploadImpl {
    /// Starts uploading `buf` as the part numbered `part_idx`.
    ///
    /// The future yields `(index, part)`. The caller stores `part` at
    /// `completed_parts[index]`, so implementations are expected to hand
    /// `part_idx` back unchanged as `index`.
    fn put_multipart_part(
        &self,
        buf: Vec<u8>,
        part_idx: usize,
    ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>>;

    /// Finishes the upload from the recorded parts.
    ///
    /// `completed_parts` is indexed by part index; `None` marks an index
    /// for which no part was recorded.
    fn complete(
        &self,
        completed_parts: Vec<Option<UploadPart>>,
    ) -> BoxFuture<'static, Result<(), io::Error>>;
}
/// Handle returned by a backend for one successfully uploaded part.
#[derive(Clone, Debug)]
pub(crate) struct UploadPart {
    /// Backend-assigned identifier for the part (presumably an ETag or block
    /// id, depending on the backend — confirm per implementation).
    pub content_id: String,
}
/// Buffers written bytes and uploads them as parts through a
/// [`CloudMultiPartUploadImpl`] backend; exposed as a `tokio::io::AsyncWrite`.
pub(crate) struct CloudMultiPartUpload<T: CloudMultiPartUploadImpl> {
    /// Backend that performs the part uploads and the completion request.
    inner: Arc<T>,
    /// Finished parts indexed by part index; `None` where no part is recorded.
    completed_parts: Vec<Option<UploadPart>>,
    /// Part uploads in flight, each resolving to `(part index, part)`.
    tasks: FuturesUnordered<BoxedTryFuture<(usize, UploadPart)>>,
    /// Maximum number of entries allowed in `tasks` before new parts wait.
    max_concurrency: usize,
    /// Bytes accepted from writers but not yet submitted as a part.
    current_buffer: Vec<u8>,
    /// Buffered byte count that must be exceeded before a write submits a part.
    min_part_size: usize,
    /// Index assigned to the next part submitted to `inner`.
    current_part_idx: usize,
    /// Future of the completion request, created during shutdown.
    completion_task: Option<BoxedTryFuture<()>>,
}
impl<T> CloudMultiPartUpload<T>
where
    T: CloudMultiPartUploadImpl,
{
    /// Creates an uploader that buffers writes and submits them to `inner`
    /// as parts, keeping at most `max_concurrency` part uploads in flight.
    ///
    /// A `max_concurrency` of `0` is treated as `1`. With zero slots a full
    /// buffer could never be submitted, and writes would stay pending forever.
    pub fn new(inner: T, max_concurrency: usize) -> Self {
        Self {
            inner: Arc::new(inner),
            completed_parts: Vec::new(),
            tasks: FuturesUnordered::new(),
            max_concurrency: max_concurrency.max(1),
            current_buffer: Vec::new(),
            // S3 rejects every part except the last if it is smaller than
            // 5 MiB (5_242_880 bytes). Presumably the other backends'
            // minimums are no stricter — TODO confirm per backend.
            min_part_size: 5 * 1024 * 1024,
            current_part_idx: 0,
            completion_task: None,
        }
    }

    /// Polls the in-flight part uploads and records each finished part in
    /// `completed_parts` at the index that its upload reported.
    ///
    /// If any upload is still in flight on return, `cx`'s waker has been
    /// registered with it.
    ///
    /// # Errors
    ///
    /// Returns the first error yielded by a finished upload. Uploads that
    /// have not finished remain in `tasks`.
    pub fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Result<(), io::Error> {
        if self.tasks.is_empty() {
            return Ok(());
        }
        while let Poll::Ready(Some(res)) = self.tasks.poll_next_unpin(cx) {
            let (part_idx, part) = res?;
            // Parts can finish in any order. Only ever grow the vector, so
            // that a later part recorded earlier in this loop is not
            // truncated away when a lower-indexed part finishes.
            if part_idx >= self.completed_parts.len() {
                self.completed_parts.resize(part_idx + 1, None);
            }
            self.completed_parts[part_idx] = Some(part);
        }
        Ok(())
    }
}
impl<T> AsyncWrite for CloudMultiPartUpload<T>
where
    T: CloudMultiPartUploadImpl + Send + Sync,
{
    /// Buffers `buf`, and starts a part upload once more than `min_part_size`
    /// bytes are buffered.
    ///
    /// Returns `Poll::Pending` without consuming `buf` when a part is ready
    /// to send but `max_concurrency` uploads are already in flight.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, io::Error>> {
        // Drive in-flight uploads first so that finished ones free up slots.
        self.as_mut().poll_tasks(cx)?;

        let enough_to_send = (buf.len() + self.current_buffer.len()) > self.min_part_size;
        if !enough_to_send {
            self.current_buffer.extend_from_slice(buf);
            return Poll::Ready(Ok(buf.len()));
        }
        if self.tasks.len() >= self.max_concurrency {
            return Poll::Pending;
        }

        self.current_buffer.extend_from_slice(buf);
        let out_buffer = std::mem::take(&mut self.current_buffer);
        self.start_part_upload(out_buffer);
        // Poll again so that the newly pushed upload starts making progress.
        self.as_mut().poll_tasks(cx)?;
        Poll::Ready(Ok(buf.len()))
    }

    /// Submits any buffered bytes as a part (if a concurrency slot is free).
    /// Resolves once no upload is in flight and nothing remains buffered.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        self.as_mut().poll_tasks(cx)?;
        if !self.current_buffer.is_empty() && self.tasks.len() < self.max_concurrency {
            let out_buffer = std::mem::take(&mut self.current_buffer);
            // This goes through the shared helper so the part index advances.
            // Otherwise a write after this flush would reuse the index and
            // overwrite this part.
            self.start_part_upload(out_buffer);
        }
        self.as_mut().poll_tasks(cx)?;
        if self.tasks.is_empty() && self.current_buffer.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }

    /// Flushes all remaining data, then issues the completion request with
    /// every recorded part and polls it.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), io::Error>> {
        match self.as_mut().poll_flush(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(res) => res?,
        };
        // The struct is `Unpin` (every field is), so plain `&mut` access is fine.
        let this = self.get_mut();
        let inner = &this.inner;
        let completed_parts = &mut this.completed_parts;
        // Hand the recorded parts to `complete` only once, when the
        // completion future is first created.
        let completion_task = this
            .completion_task
            .get_or_insert_with(|| inner.complete(std::mem::take(completed_parts)));
        completion_task.as_mut().poll(cx)
    }
}

impl<T> CloudMultiPartUpload<T>
where
    T: CloudMultiPartUploadImpl,
{
    /// Starts uploading `buf` as the next part and advances
    /// `current_part_idx`, so that every submitted part gets a distinct index.
    fn start_part_upload(&mut self, buf: Vec<u8>) {
        let part_idx = self.current_part_idx;
        let task = self.inner.put_multipart_part(buf, part_idx);
        self.tasks.push(task);
        self.current_part_idx += 1;
    }
}