#[cfg(feature = "tokio")]
use std::task::{Context, Poll};
#[cfg(feature = "tokio")]
use crate::PutPayloadMut;
use crate::{PutPayload, PutResult, Result};
use async_trait::async_trait;
use futures_util::future::BoxFuture;
#[cfg(feature = "tokio")]
use tokio::task::JoinSet;
/// A part of a multipart upload, as returned by [`MultipartUpload::put_part`].
///
/// The boxed future resolves once the part upload has finished (successfully
/// or with an error).
pub type UploadPart = BoxFuture<'static, Result<()>>;
/// An interface for uploading an object in multiple parts.
///
/// Parts are submitted with [`MultipartUpload::put_part`]; the upload is then
/// finalised with [`MultipartUpload::complete`] or cancelled with
/// [`MultipartUpload::abort`].
#[async_trait]
pub trait MultipartUpload: Send + std::fmt::Debug {
/// Upload the next part of this multipart upload.
///
/// Returns a future resolving when the part has been uploaded; callers may
/// poll several returned futures concurrently to upload parts in parallel.
fn put_part(&mut self, data: PutPayload) -> UploadPart;
/// Complete this multipart upload, returning the [`PutResult`].
///
/// NOTE(review): implementations presumably expect all outstanding
/// [`UploadPart`] futures to have completed first — confirm per backend.
async fn complete(&mut self) -> Result<PutResult>;
/// Abort this multipart upload, cleaning up already-uploaded parts where
/// the backing store supports it.
async fn abort(&mut self) -> Result<()>;
}
/// Blanket implementation so a boxed (possibly unsized) [`MultipartUpload`]
/// can be used wherever the trait is expected; every method simply forwards
/// to the boxed value.
#[async_trait]
impl<W: MultipartUpload + ?Sized> MultipartUpload for Box<W> {
    fn put_part(&mut self, data: PutPayload) -> UploadPart {
        // Reborrow through the box and delegate.
        let inner: &mut W = self;
        inner.put_part(data)
    }

    async fn complete(&mut self) -> Result<PutResult> {
        let inner: &mut W = self;
        inner.complete().await
    }

    async fn abort(&mut self) -> Result<()> {
        let inner: &mut W = self;
        inner.abort().await
    }
}
/// A helper that buffers writes and uploads them as a [`MultipartUpload`]
/// using fixed-size parts, spawning each part upload onto the tokio runtime.
#[cfg(feature = "tokio")]
#[derive(Debug)]
pub struct WriteMultipart {
/// The underlying multipart upload parts are submitted to
upload: Box<dyn MultipartUpload>,
/// Bytes buffered so far that do not yet fill a complete part
buffer: PutPayloadMut,
/// Target size, in bytes, of each uploaded part
chunk_size: usize,
/// In-flight part-upload tasks spawned onto the tokio runtime
tasks: JoinSet<Result<()>>,
}
#[cfg(feature = "tokio")]
impl WriteMultipart {
/// Create a new [`WriteMultipart`] with a default part size of 5 MiB.
pub fn new(upload: Box<dyn MultipartUpload>) -> Self {
Self::new_with_chunk_size(upload, 5 * 1024 * 1024)
}
/// Create a new [`WriteMultipart`] uploading parts of `chunk_size` bytes.
pub fn new_with_chunk_size(upload: Box<dyn MultipartUpload>, chunk_size: usize) -> Self {
Self {
upload,
chunk_size,
buffer: PutPayloadMut::new(),
tasks: Default::default(),
}
}
/// Poll, completing in-flight part uploads, until fewer than
/// `max_concurrency` remain outstanding.
///
/// Returns the first error encountered by any completed task, if any.
pub fn poll_for_capacity(
&mut self,
cx: &mut Context<'_>,
max_concurrency: usize,
) -> Poll<Result<()>> {
// Drain finished tasks while at or above the concurrency limit.
// The `unwrap` is sound: `poll_join_next` only yields `None` when the
// set is empty, which the loop condition excludes. The first `?`
// propagates a task join/panic error, the second the upload's own error.
while !self.tasks.is_empty() && self.tasks.len() >= max_concurrency {
futures_core::ready!(self.tasks.poll_join_next(cx)).unwrap()??
}
Poll::Ready(Ok(()))
}
/// Wait until fewer than `max_concurrency` part uploads remain in flight,
/// returning the first error encountered, if any.
pub async fn wait_for_capacity(&mut self, max_concurrency: usize) -> Result<()> {
futures_util::future::poll_fn(|cx| self.poll_for_capacity(cx, max_concurrency)).await
}
/// Buffer `buf`, copying it into the internal buffer, and spawn an upload
/// task each time a full `chunk_size` part accumulates.
///
/// Note: this applies no backpressure — tasks are spawned without limit.
/// Use [`Self::wait_for_capacity`] to bound concurrent uploads.
pub fn write(&mut self, mut buf: &[u8]) {
while !buf.is_empty() {
// Space left in the current part.
let remaining = self.chunk_size - self.buffer.content_length();
let to_read = buf.len().min(remaining);
self.buffer.extend_from_slice(&buf[..to_read]);
// Buffer reached exactly `chunk_size`: flush it as a part.
if to_read == remaining {
let buffer = std::mem::take(&mut self.buffer);
self.put_part(buffer.into())
}
buf = &buf[to_read..]
}
}
/// Buffer `bytes` without copying the payload data: full-part prefixes are
/// split off via `Bytes::split_to` and only a trailing partial part is
/// retained in the buffer.
///
/// Like [`Self::write`], this applies no backpressure.
pub fn put(&mut self, mut bytes: bytes::Bytes) {
while !bytes.is_empty() {
let remaining = self.chunk_size - self.buffer.content_length();
// Not enough to complete a part: stash the rest and return.
if bytes.len() < remaining {
self.buffer.push(bytes);
return;
}
self.buffer.push(bytes.split_to(remaining));
let buffer = std::mem::take(&mut self.buffer);
self.put_part(buffer.into())
}
}
/// Spawn an upload task for `part` onto the tokio runtime.
pub(crate) fn put_part(&mut self, part: PutPayload) {
self.tasks.spawn(self.upload.put_part(part));
}
/// Abort this upload: cancel and await all in-flight part tasks, then
/// abort the underlying multipart upload.
pub async fn abort(mut self) -> Result<()> {
self.tasks.shutdown().await;
self.upload.abort().await
}
/// Flush any buffered data as a final (possibly short) part, wait for all
/// in-flight parts, then complete the upload.
///
/// On completion failure the upload is aborted before the error is
/// returned. Note: if the abort itself fails, its error is propagated by
/// `?` and the original completion error `e` is lost.
pub async fn finish(mut self) -> Result<PutResult> {
if !self.buffer.is_empty() {
let part = std::mem::take(&mut self.buffer);
self.put_part(part.into())
}
// `max_concurrency == 0` drains every outstanding task.
self.wait_for_capacity(0).await?;
match self.upload.complete().await {
Err(e) => {
self.tasks.shutdown().await;
self.upload.abort().await?;
Err(e)
}
Ok(result) => Ok(result),
}
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use futures_util::FutureExt;
    use parking_lot::Mutex;
    use rand::prelude::StdRng;
    // Fix: `random`, `random_range` and `random_bool` are provided by the
    // `rand::Rng` extension trait; the rand crate exposes no `RngExt` trait,
    // so the previous `use rand::{RngExt, SeedableRng};` did not compile.
    use rand::{Rng, SeedableRng};

    use crate::memory::InMemory;
    use crate::path::Path;
    use crate::throttle::{ThrottleConfig, ThrottledStore};
    use crate::ObjectStoreExt;

    use super::*;

    /// `wait_for_capacity` should not resolve while more than
    /// `max_concurrency` part uploads are still in flight.
    #[tokio::test]
    async fn test_concurrency() {
        let config = ThrottleConfig {
            // Slow each put slightly so tasks stay in flight long enough
            // for the capacity check below to observe them.
            wait_put_per_call: Duration::from_millis(1),
            ..Default::default()
        };

        let path = Path::from("foo");
        let store = ThrottledStore::new(InMemory::new(), config);
        let upload = store.put_multipart(&path).await.unwrap();
        let mut write = WriteMultipart::new_with_chunk_size(upload, 10);

        // 20 writes of 5 bytes with a 10-byte chunk size => 10 spawned parts.
        for _ in 0..20 {
            write.write(&[0; 5]);
        }
        // With 10 throttled parts outstanding, capacity is not yet available…
        assert!(write.wait_for_capacity(10).now_or_never().is_none());
        // …but it eventually becomes available without error.
        write.wait_for_capacity(10).await.unwrap()
    }

    /// A [`MultipartUpload`] that records every part it receives, for
    /// asserting on chunking behaviour.
    #[derive(Debug, Default)]
    struct InstrumentedUpload {
        chunks: Arc<Mutex<Vec<PutPayload>>>,
    }

    #[async_trait]
    impl MultipartUpload for InstrumentedUpload {
        fn put_part(&mut self, data: PutPayload) -> UploadPart {
            self.chunks.lock().push(data);
            futures_util::future::ready(Ok(())).boxed()
        }

        async fn complete(&mut self) -> Result<PutResult> {
            Ok(PutResult {
                e_tag: None,
                version: None,
            })
        }

        async fn abort(&mut self) -> Result<()> {
            // The tests never abort; fail loudly if they do.
            unimplemented!()
        }
    }

    /// Exhaustively mixes `write` (copying) and `put` (zero-copy) calls with
    /// random payload sizes, checking that the reassembled data matches and
    /// every part except the last is exactly `chunk_size` bytes.
    #[tokio::test]
    async fn test_write_multipart() {
        let mut rng = StdRng::seed_from_u64(42);

        // `method` is the probability of using `put` over `write`.
        for method in [0.0, 0.5, 1.0] {
            for _ in 0..10 {
                for chunk_size in [1, 17, 23] {
                    let upload = Box::<InstrumentedUpload>::default();
                    let chunks = Arc::clone(&upload.chunks);
                    let mut write = WriteMultipart::new_with_chunk_size(upload, chunk_size);

                    let mut expected = Vec::with_capacity(1024);
                    for _ in 0..50 {
                        let chunk_size = rng.random_range(0..30);
                        let data: Vec<_> = (0..chunk_size).map(|_| rng.random()).collect();
                        expected.extend_from_slice(&data);
                        match rng.random_bool(method) {
                            true => write.put(data.into()),
                            false => write.write(&data),
                        }
                    }
                    write.finish().await.unwrap();

                    // The concatenation of all parts equals the input bytes.
                    let chunks = chunks.lock();
                    let actual: Vec<_> = chunks.iter().flatten().flatten().copied().collect();
                    assert_eq!(expected, actual);

                    // Every part but the last is exactly `chunk_size` long…
                    for chunk in chunks.iter().take(chunks.len() - 1) {
                        assert_eq!(chunk.content_length(), chunk_size)
                    }
                    // …and the last is at most `chunk_size` long.
                    let last_chunk = chunks.last().unwrap().content_length();
                    assert!(last_chunk <= chunk_size, "{chunk_size}");
                }
            }
        }
    }
}