#![allow(dead_code, clippy::must_use_candidate)]
use bytes::{Bytes, BytesMut};
use futures::{stream, Stream, StreamExt, TryStreamExt};
use md5::{Digest, Md5};
use rusoto_core::{HttpClient, RusotoError};
use rusoto_credential::StaticProvider;
use rusoto_s3::{
    CompleteMultipartUploadRequest, CompletedMultipartUpload, CompletedPart, CreateMultipartUploadRequest,
    GetObjectRequest, HeadObjectRequest, ListObjectsV2Error, ListObjectsV2Output, ListObjectsV2Request, Object,
    S3Client, UploadPartRequest, S3 as _,
};
use std::{error::Error, future::Future, io, str, time::Duration};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt};
pub type S3 = S3Client;
pub type Region = rusoto_core::Region;
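/// Builds an `S3Client` from static credentials and a region.
///
/// A minimal usage sketch; the credentials here are placeholders, not working values:
/// ```ignore
/// let s3 = s3_new("ACCESS_KEY_ID", "SECRET_ACCESS_KEY", Region::UsEast1);
/// ```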
pub fn s3_new(access_key_id: &str, secret_access_key: &str, region: Region) -> S3 {
S3Client::new_with(
HttpClient::new().expect("failed to create request dispatcher"),
StaticProvider::new(access_key_id.to_owned(), secret_access_key.to_owned(), None, None),
region,
)
}
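/// Streams `read` to `s3://bucket/key` as a multipart upload: the input is chunked into
/// fixed-size parts, up to `parallelism` parts are uploaded concurrently, and a progress
/// bar tracks bytes sent. `cache_control` is a `max-age` in seconds; `None` means
/// `no-store, must-revalidate`.
///
/// A hedged usage sketch; the bucket, key, and file path are made-up examples:
/// ```ignore
/// let file = tokio::fs::File::open("data.bin").await?;
/// upload(&s3, "my-bucket".into(), "my-key".into(), None, "application/octet-stream".into(), None, file).await?;
/// ```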
#[allow(clippy::too_many_lines)]
pub async fn upload(
s3_client: &S3, bucket: String, key: String, content_encoding: Option<String>, content_type: String, cache_control: Option<u64>,
read: impl AsyncRead,
) -> Result<(), Box<dyn Error>> {
let pb = &indicatif::ProgressBar::new(0).with_finish(indicatif::ProgressFinish::AndLeave);
pb.set_style(
indicatif::ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
.unwrap()
.progress_chars("#>-"),
);
    // 16 MiB parts; S3 requires every part except the last to be between 5 MiB and 5 GiB.
    let part_size = 16 * 1024 * 1024;
    assert!((5_242_880..=5_368_709_120).contains(&part_size));
    let parallelism = 10;
let upload_id = &rusoto_retry(|| async {
s3_client
.create_multipart_upload(CreateMultipartUploadRequest {
bucket: bucket.clone(),
key: key.clone(),
content_encoding: content_encoding.clone(),
content_type: Some(content_type.clone()),
cache_control: Some(
cache_control.map_or_else(|| String::from("no-store, must-revalidate"), |secs| format!("public, max-age={}", secs)),
),
..CreateMultipartUploadRequest::default()
})
.await
})
.await?
.upload_id
.unwrap();
tokio::pin!(read);
let (bucket, key) = (&bucket, &key);
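    // Chunk `read` into `part_size` pieces; each non-empty chunk becomes one uploaded part.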
    let e_tags = stream::unfold(read, |mut read| async move {
        let mut buf = BytesMut::with_capacity(part_size);
        // Fill the buffer until EOF or until a full part has accumulated.
        loop {
let n = match read.read_buf(&mut buf).await {
Ok(n) => n,
Err(e) => return Some((Err(e), read)),
};
assert!(buf.len() <= part_size);
if n == 0 || buf.len() == part_size {
break;
}
}
pb.inc_length(buf.len().try_into().unwrap());
(!buf.is_empty()).then(|| (Ok::<_, io::Error>(buf.freeze()), read))
})
.enumerate()
.map(|(i, buf): (usize, Result<Bytes, _>)| async move {
let buf = &buf?;
        // S3 caps a multipart upload at 10,000 parts.
        assert!(i < 10_000);
        Ok::<_, io::Error>(
rusoto_retry(|| async {
let buf = buf.clone();
let buf_len = buf.len();
let ret = s3_client
.upload_part(UploadPartRequest {
bucket: bucket.clone(),
key: key.clone(),
upload_id: upload_id.clone(),
part_number: (i + 1).try_into().unwrap(),
                        // `Content-MD5` lets S3 verify part integrity: base64 of the raw 16-byte digest.
                        content_md5: Some(base64::encode(&Md5::new().chain_update(&buf).finalize())),
body: Some(rusoto_core::ByteStream::new_with_size(stream::once(async { Ok(buf) }), buf_len)),
..UploadPartRequest::default()
})
.await;
                // Advance the progress bar only on success so retries don't over-count.
                if ret.is_ok() {
                    pb.inc(buf_len.try_into().unwrap());
                }
                ret
})
.await
.unwrap()
.e_tag
.unwrap(),
)
})
.buffered(parallelism)
.try_collect::<Vec<String>>()
.await?;
let _ = rusoto_retry(|| async {
s3_client
.complete_multipart_upload(CompleteMultipartUploadRequest {
bucket: bucket.clone(),
key: key.clone(),
upload_id: upload_id.clone(),
multipart_upload: Some(CompletedMultipartUpload {
parts: Some(
e_tags
.iter()
.enumerate()
.map(|(i, e_tag)| CompletedPart { part_number: Some((i + 1).try_into().unwrap()), e_tag: Some(e_tag.clone()) })
.collect(),
),
}),
..CompleteMultipartUploadRequest::default()
})
.await
})
.await?;
Ok(())
}
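/// Opens `s3://bucket/key` for reading. Objects uploaded in multiple parts are fetched
/// part-by-part with bounded concurrency; single-part objects are streamed through a
/// large buffered reader. A progress bar tracks bytes received.
///
/// A hedged usage sketch; the bucket and key are made-up examples:
/// ```ignore
/// use tokio::io::AsyncReadExt;
/// let reader = download(&s3, "my-bucket".into(), "my-key".into()).await?;
/// tokio::pin!(reader);
/// let mut contents = Vec::new();
/// reader.read_to_end(&mut contents).await?;
/// ```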
#[allow(clippy::too_many_lines)]
pub async fn download(s3_client: &S3, bucket: String, key: String) -> Result<impl AsyncBufRead + '_, io::Error> {
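    // HEAD with `part_number: 1` returns the first part's size (our part size) and, for
    // multipart objects, a `parts_count`.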
let head = rusoto_retry(|| async {
s3_client
.head_object(HeadObjectRequest { bucket: bucket.clone(), key: key.clone(), part_number: Some(1), ..HeadObjectRequest::default() })
.await
})
.await
.map_err(|e| io::Error::new(io::ErrorKind::NotFound, e.to_string()))?;
let (part_size, parts): (u64, Option<u64>) = (head.content_length.unwrap().try_into().unwrap(), head.parts_count.map(|x| x.try_into().unwrap()));
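    // For multipart objects a second HEAD, without `part_number`, reports the full length.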
    let length: u64 = if parts.is_some() {
rusoto_retry(|| async {
s3_client
.head_object(HeadObjectRequest { bucket: bucket.clone(), key: key.clone(), part_number: None, ..HeadObjectRequest::default() })
.await
})
.await
.unwrap()
.content_length
.unwrap()
.try_into()
.unwrap()
} else {
part_size
};
    // Sanity check: the total length must land inside the final part's byte range.
    assert!(part_size * (parts.unwrap_or(1) - 1) < length && length <= part_size * parts.unwrap_or(1));
let pb = indicatif::ProgressBar::new(length).with_finish(indicatif::ProgressFinish::AndLeave);
pb.set_style(
indicatif::ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})")
.unwrap()
.progress_chars("#>-"),
);
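    // Multipart objects are reassembled from concurrently fetched parts; single-part
    // objects are streamed directly.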
if let Some(parts) = parts {
let parallelism = 20;
let body = tokio_util::io::StreamReader::new(
        stream::iter((0..parts).map(move |i| {
let (pb, bucket, key) = (pb.clone(), bucket.clone(), key.clone());
async move {
let range = part_size * i..(part_size * (i + 1)).min(length);
let cap: usize = (range.end - range.start).try_into().unwrap();
                // Restart the whole part from scratch if the body stream dies mid-read.
                'async_read: loop {
let body = rusoto_retry(|| async {
s3_client
.get_object(GetObjectRequest {
bucket: bucket.clone(),
key: key.clone(),
part_number: Some((i + 1).try_into().unwrap()),
..GetObjectRequest::default()
})
.await
})
.await
.unwrap()
.body
.unwrap()
.into_async_read();
let mut body = pb.wrap_async_read(body);
let mut buf = BytesMut::with_capacity(cap);
while buf.len() != cap {
let _bytes = match body.read_buf(&mut buf).await {
Ok(bytes) => bytes,
Err(e) => {
println!("Got transient http error: {:?}. Retrying.", e);
continue 'async_read;
}
};
assert!(buf.len() <= cap);
}
break Ok::<_, io::Error>(buf);
}
}
}))
.buffered(parallelism),
);
Ok(tokio_util::either::Either::Left(body))
} else {
let body = rusoto_retry(|| async {
s3_client.get_object(GetObjectRequest { bucket: bucket.clone(), key: key.clone(), ..GetObjectRequest::default() }).await
})
.await
.unwrap()
.body
.unwrap()
.into_async_read();
let body = pb.wrap_async_read(body);
let body = tokio::io::BufReader::with_capacity(16 * 1024 * 1024, body);
Ok(tokio_util::either::Either::Right(body))
}
}
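/// Retries `f` until it succeeds or fails with a non-transient error, sleeping between
/// attempts. HTTP dispatch failures, throttling/server statuses (429, 5xx, 509), and
/// `RequestTimeTooSkewed` responses are treated as transient.
///
/// A hedged sketch; the request shown is an arbitrary example:
/// ```ignore
/// let head = rusoto_retry(|| async {
///     s3.head_object(HeadObjectRequest { bucket: bucket.clone(), key: key.clone(), ..HeadObjectRequest::default() }).await
/// })
/// .await?;
/// ```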
pub async fn rusoto_retry<F, U, T, E>(mut f: F) -> Result<T, RusotoError<E>>
where
F: FnMut() -> U,
U: Future<Output = Result<T, RusotoError<E>>>,
{
loop {
match f().await {
Err(RusotoError::HttpDispatch(e)) => {
println!("Got transient error: {:?}. Retrying.", e);
}
Err(RusotoError::Unknown(response))
if matches!(response.status.as_u16(), 429 | 500 | 502 | 503 | 504 | 509)
|| (response.status == 403 && str::from_utf8(&response.body).unwrap().contains("RequestTimeTooSkewed")) =>
{
println!("Got transient response error: {:?}. Retrying.", response);
}
res => break res,
}
        // Fixed two-second pause between attempts (no exponential backoff).
        tokio::time::sleep(Duration::from_secs(2)).await;
}
}
type StreamItem = Result<ListObjectsV2Output, RusotoError<ListObjectsV2Error>>;
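/// Streams `ListObjectsV2` pages, following `next_continuation_token` until the listing
/// is exhausted; the first error terminates the stream.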
fn stream_list_objects_v2(request: ListObjectsV2Request, s3_client: &S3Client) -> impl Stream<Item = StreamItem> + '_ {
type State<'c> = Option<(ListObjectsV2Request, &'c S3Client)>;
async fn retrieve_list_fragment(state: State<'_>) -> Option<(StreamItem, State<'_>)> {
let (mut request, s3_client) = state?;
let response = rusoto_retry(|| s3_client.list_objects_v2(request.clone())).await;
match response {
Ok(output) if output.next_continuation_token.is_some() => {
request.continuation_token = output.next_continuation_token.clone();
Some((Ok(output), Some((request, s3_client))))
}
Ok(output) => Some((Ok(output), None)),
Err(error) => Some((Err(error), None)),
}
}
stream::unfold(Some((request, s3_client)), retrieve_list_fragment)
}
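/// Lists every object in `bucket`, draining all pages into memory before returning.
///
/// A hedged usage sketch; the bucket name is a made-up example:
/// ```ignore
/// let keys: Vec<String> = list_objects("my-bucket", &s3).await?.filter_map(|o| o.key).collect();
/// ```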
pub async fn list_objects(bucket: &str, s3_client: &S3Client) -> Result<impl Iterator<Item = Object>, RusotoError<ListObjectsV2Error>> {
let responses: Vec<Result<ListObjectsV2Output, _>> =
stream_list_objects_v2(ListObjectsV2Request { bucket: bucket.to_string(), ..ListObjectsV2Request::default() }, s3_client).collect().await;
let mut objects = vec![];
for response in responses {
let output = response?;
        // `contents` is absent for an empty page (e.g. an empty bucket); treat it as empty.
        let contents = output.contents.unwrap_or_default();
objects.push(contents);
}
Ok(objects.into_iter().flatten())
}