use async_trait::async_trait;
use futures::{stream::FuturesUnordered, Future, StreamExt};
use std::{io, pin::Pin, sync::Arc, task::Poll};
use tokio::io::AsyncWrite;
use crate::Result;
/// Heap-allocated, type-erased, `Send` future that resolves to either a `T`
/// or an [`io::Error`].
///
/// Gives the in-flight part uploads and the final completion future a single
/// nameable type so they can be stored in struct fields.
type BoxedTryFuture<T> =
    Pin<Box<dyn Future<Output = Result<T, io::Error>> + Send>>;
/// Backend hooks that [`WriteMultiPart`] drives to perform a multipart upload.
///
/// The `Send + Sync + 'static` bounds are required because the writer shares
/// the implementation through an `Arc` captured by boxed `Send` futures.
#[async_trait]
pub trait PutPart: Send + Sync + 'static {
    /// Called by [`WriteMultiPart`] with one buffered chunk of the stream.
    ///
    /// `part_idx` is the zero-based position of the chunk. Several calls may
    /// be in flight at once (up to the writer's concurrency limit) and they
    /// may finish in any order; the returned [`PartId`] is stored under
    /// `part_idx`.
    async fn put_part(&self, buf: Vec<u8>, part_idx: usize) -> Result<PartId>;

    /// Called once, after every part has finished, with one [`PartId`] per
    /// part ordered by part index.
    async fn complete(&self, completed_parts: Vec<PartId>) -> Result<()>;
}
/// Identifier produced by [`PutPart::put_part`] for a finished part and
/// handed back, in part order, to [`PutPart::complete`].
#[derive(Debug, Clone)]
pub struct PartId {
    /// Backend-defined identifier of the part (presumably an ETag or similar
    /// content id — confirm against the implementing store).
    pub content_id: String,
}
/// [`AsyncWrite`] adapter that slices the written byte stream into
/// `part_size` chunks and uploads them concurrently through a [`PutPart`]
/// backend.
///
/// Full chunks are submitted as they accumulate; the trailing partial chunk
/// is submitted on shutdown, after which [`PutPart::complete`] receives every
/// part id in index order.
pub struct WriteMultiPart<T: PutPart> {
    /// Upload backend, shared with each in-flight part future.
    inner: Arc<T>,
    /// Ids of finished parts, indexed by part index; `None` marks a slot whose
    /// part has not finished yet.
    completed_parts: Vec<Option<PartId>>,
    /// In-flight part uploads, each resolving to `(part_idx, PartId)`.
    tasks: FuturesUnordered<BoxedTryFuture<(usize, PartId)>>,
    /// Upper bound on the number of futures allowed in `tasks` at once.
    max_concurrency: usize,
    /// Bytes accumulated for the part that has not been submitted yet.
    current_buffer: Vec<u8>,
    /// Buffer length at which a part is submitted.
    part_size: usize,
    /// Index the next submitted part will receive.
    current_part_idx: usize,
    /// Future driving [`PutPart::complete`]; created lazily during shutdown.
    completion_task: Option<BoxedTryFuture<()>>,
}
impl<T: PutPart> WriteMultiPart<T> {
    /// Creates a writer that buffers data into 10 MiB parts and uploads them
    /// through `inner`, keeping at most `max_concurrency` uploads in flight.
    ///
    /// A `max_concurrency` of `0` is treated as `1`. With zero permitted
    /// uploads, no part could ever be submitted. Writes would then return
    /// `Pending` with no in-flight future to wake them, stalling forever.
    pub fn new(inner: T, max_concurrency: usize) -> Self {
        Self {
            inner: Arc::new(inner),
            completed_parts: Vec::new(),
            tasks: FuturesUnordered::new(),
            max_concurrency: max_concurrency.max(1),
            current_buffer: Vec::new(),
            part_size: 10 * 1024 * 1024,
            current_part_idx: 0,
            completion_task: None,
        }
    }

    /// Copies as much of `buf[offset..]` as fits into the current part buffer
    /// and returns the number of bytes copied.
    fn add_to_buffer(mut self: Pin<&mut Self>, buf: &[u8], offset: usize) -> usize {
        let remaining_capacity = self.part_size - self.current_buffer.len();
        let to_copy = std::cmp::min(remaining_capacity, buf.len() - offset);
        self.current_buffer
            .extend_from_slice(&buf[offset..offset + to_copy]);
        to_copy
    }

    /// Drives the in-flight part uploads and records every one that finished.
    ///
    /// Returns early with the error of the first failed upload it observes.
    /// While uploads remain pending, polling them registers `cx`'s waker.
    fn poll_tasks(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Result<(), io::Error> {
        if self.tasks.is_empty() {
            return Ok(());
        }
        while let Poll::Ready(Some(res)) = self.tasks.poll_next_unpin(cx) {
            let (part_idx, part) = res?;
            // Parts may finish out of order. Grow the table so `part_idx` is a
            // valid slot, leaving earlier unfinished slots as `None`.
            let total_parts = self.completed_parts.len();
            self.completed_parts
                .resize(std::cmp::max(part_idx + 1, total_parts), None);
            self.completed_parts[part_idx] = Some(part);
        }
        Ok(())
    }

    /// Submits any partially filled buffer as the final part and waits for all
    /// outstanding uploads.
    ///
    /// Resolves to `Ready(Ok(()))` once the buffer is empty and no upload is
    /// in flight. Otherwise it returns `Pending`, for example while the
    /// concurrency limit prevents submitting the last part.
    fn final_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        // Reap finished uploads first so a slot may free up for the last part.
        self.as_mut().poll_tasks(cx)?;
        if !self.current_buffer.is_empty() && self.tasks.len() < self.max_concurrency {
            let out_buffer: Vec<u8> = std::mem::take(&mut self.current_buffer);
            let inner = Arc::clone(&self.inner);
            let part_idx = self.current_part_idx;
            self.tasks.push(Box::pin(async move {
                let upload_part = inner.put_part(out_buffer, part_idx).await?;
                Ok((part_idx, upload_part))
            }));
            // Advance the index exactly as `poll_write` does, so no index is
            // ever handed out twice.
            self.current_part_idx += 1;
        }
        // `push` does not poll the new future; poll again so it starts running
        // and registers the waker.
        self.as_mut().poll_tasks(cx)?;
        if self.tasks.is_empty() && self.current_buffer.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
}
impl<T: PutPart> AsyncWrite for WriteMultiPart<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.as_mut().poll_tasks(cx)?;
let mut offset = 0;
loop {
offset += self.as_mut().add_to_buffer(buf, offset);
if self.current_buffer.len() < self.part_size
|| self.tasks.len() >= self.max_concurrency
{
break;
}
let new_buffer = Vec::with_capacity(self.part_size);
let out_buffer = std::mem::replace(&mut self.current_buffer, new_buffer);
let inner = Arc::clone(&self.inner);
let part_idx = self.current_part_idx;
self.tasks.push(Box::pin(async move {
let upload_part = inner.put_part(out_buffer, part_idx).await?;
Ok((part_idx, upload_part))
}));
self.current_part_idx += 1;
self.as_mut().poll_tasks(cx)?;
}
if offset == 0 && !buf.is_empty() {
Poll::Pending
} else {
Poll::Ready(Ok(offset))
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
self.as_mut().poll_tasks(cx)?;
if self.tasks.is_empty() {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), io::Error>> {
match self.as_mut().final_flush(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(res) => res?,
};
let parts = std::mem::take(&mut self.completed_parts);
let parts = parts
.into_iter()
.enumerate()
.map(|(idx, part)| {
part.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!("Missing information for upload part {idx}"),
)
})
})
.collect::<Result<_, _>>()?;
let inner = Arc::clone(&self.inner);
let completion_task = self.completion_task.get_or_insert_with(|| {
Box::pin(async move {
inner.complete(parts).await?;
Ok(())
})
});
Pin::new(completion_task).poll(cx)
}
}
impl<T: PutPart> std::fmt::Debug for WriteMultiPart<T> {
    /// Formats the writer's bookkeeping state.
    ///
    /// The pending part buffer is reported by length only. It can hold up to
    /// `part_size` bytes (10 MiB as constructed by `new`), which would
    /// otherwise be printed one byte at a time. `inner` and `completion_task`
    /// are omitted because they are not required to implement `Debug`; only
    /// whether completion has started is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WriteMultiPart")
            .field("completed_parts", &self.completed_parts)
            .field("tasks", &self.tasks)
            .field("max_concurrency", &self.max_concurrency)
            .field("current_buffer_len", &self.current_buffer.len())
            .field("part_size", &self.part_size)
            .field("current_part_idx", &self.current_part_idx)
            .field("completion_started", &self.completion_task.is_some())
            .finish_non_exhaustive()
    }
}