use crate::encryption::ChunkSizeKb;
use crate::stream::{EncStreamReader, LastStreamElement, StreamChunk};
use crate::value::EncValueHeader;
use crate::CryptrError;
use async_trait::async_trait;
use bytes::BytesMut;
use flume::Sender;
use futures::channel::oneshot;
use futures::{pin_mut, StreamExt};
use s3_simple::Bucket;
use std::fmt::Formatter;
use std::time::Duration;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tokio::time::Instant;
use tokio::{sync, time};
use tracing::{debug, error};
/// Streaming reader that fetches a single object from an S3 bucket and feeds it
/// into the encryption / decryption pipeline via the `EncStreamReader` trait.
#[derive(Debug)]
pub struct S3Reader<'a> {
/// Bucket the object is fetched from (used for the `head` and `get` requests).
pub bucket: &'a Bucket,
/// Name / key of the object inside `bucket`.
pub object: &'a str,
/// When `true`, a background task periodically prints the download progress
/// to stdout.
pub print_progress: bool,
}
#[async_trait]
impl EncStreamReader for S3Reader<'_> {
/// Writes a short, human-readable description of this reader (bucket name
/// and object key) into the given formatter.
fn debug_reader(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
    let bucket_name = &self.bucket.name;
    let object_key = self.object;
    f.write_fmt(format_args!(
        "S3Reader(Bucket: {}, Object: {})",
        bucket_name, object_key,
    ))
}
/// Spawns a task that streams the S3 object and forwards it to `tx` in
/// plain-text chunks ready for encryption.
///
/// The object is re-chunked so that every element sent with
/// `LastStreamElement::No` is exactly `chunk_size` bytes long. The final
/// element (`LastStreamElement::Yes`) carries the remaining bytes and is never
/// larger than `chunk_size`. An empty object results in a single, empty last
/// element.
///
/// # Errors
///
/// Returns an error directly if the `head` or `get` request fails. Errors that
/// occur while streaming are sent through `tx` and are additionally returned
/// from the spawned task's `JoinHandle`.
#[tracing::instrument]
async fn spawn_reader_encryption(
    self,
    chunk_size: ChunkSizeKb,
    tx: Sender<Result<(LastStreamElement, StreamChunk), CryptrError>>,
) -> Result<JoinHandle<Result<(), CryptrError>>, CryptrError> {
    let head = self.bucket.head(self.object).await?;
    // Without a content length the progress printer gets `usize::MAX` as target.
    let content_length = head
        .content_length
        .map(|c| c as usize)
        .unwrap_or(usize::MAX);
    let tx_progress =
        Self::spawn_progress(self.print_progress, self.object, content_length).await;

    let resp = self.bucket.get(self.object).await?;
    debug!("resp: {resp:?}");

    let handle = tokio::spawn(async move {
        let stream = resp.bytes_stream();
        pin_mut!(stream);
        debug!("stream pinned");

        let chunk_size = chunk_size.value_bytes() as usize;
        let mut buf = BytesMut::with_capacity(chunk_size);
        let mut total = 0;

        // Fetch the first frame up front so that an empty object or an immediate
        // S3 error is handled explicitly instead of panicking inside the loop.
        let mut bytes = match stream.next().await {
            Some(Ok(bytes)) => bytes,
            Some(Err(err)) => {
                let msg = format!("S3 bucket error: {err}");
                tx.send_async(Err(CryptrError::S3(msg.clone()))).await?;
                return Err(CryptrError::S3(msg));
            }
            None => {
                // Empty object: emit a single, empty last element.
                debug!("sending last element with len: {}", buf.len());
                tx.send_async(Ok((LastStreamElement::Yes, StreamChunk(buf.to_vec()))))
                    .await?;
                debug!("Read {} bytes", total);
                return Ok(());
            }
        };

        loop {
            total += bytes.len();
            buf.extend(bytes);
            debug!("buf len: {:?}", buf.len());
            // Fails only when nobody listens for progress, which is fine.
            let _ = tx_progress.send(total);

            // Look ahead by one frame: we need to know whether the stream has
            // ended before deciding which chunk is the last one.
            let next = match stream.next().await {
                None => None,
                Some(Ok(next)) => Some(next),
                // Only the `Err` case is left here.
                Some(res) => {
                    debug!("stream rest in loop error: {res:?}");
                    let msg = format!("{res:?}");
                    tx.send_async(Err(CryptrError::S3(msg.clone()))).await?;
                    return Err(CryptrError::S3(msg));
                }
            };

            // Drain ALL full chunks. A single network frame may be much bigger
            // than `chunk_size`, and the last element must never exceed it. The
            // strict `>` keeps up to `chunk_size` bytes back for the last element.
            while buf.len() > chunk_size {
                let chunk = buf.split_to(chunk_size);
                debug!(
                    "sending non-last chunk with len: {} with data left in buf: {}",
                    chunk.len(),
                    buf.len()
                );
                tx.send_async(Ok((LastStreamElement::No, StreamChunk(chunk.to_vec()))))
                    .await?;
            }

            match next {
                Some(next) => bytes = next,
                None => {
                    debug!("sending last element with len: {}", buf.len());
                    tx.send_async(Ok((LastStreamElement::Yes, StreamChunk(buf.to_vec()))))
                        .await?;
                    break;
                }
            }
        }

        debug!("Read {} bytes", total);
        Ok(())
    });

    Ok(handle)
}
#[tracing::instrument]
async fn spawn_reader_decryption(
self,
tx_init: oneshot::Sender<(EncValueHeader, Vec<u8>)>,
tx: Sender<Result<(LastStreamElement, StreamChunk), CryptrError>>,
) -> Result<JoinHandle<Result<(), CryptrError>>, CryptrError> {
let head = self.bucket.head(self.object).await?;
let content_length = head
.content_length
.map(|c| c as usize)
.unwrap_or(usize::MAX);
let tx_progress =
Self::spawn_progress(self.print_progress, self.object, content_length).await;
let resp = self.bucket.get(self.object).await?;
let (tx_init_internal, rx_init) = flume::unbounded();
tokio::spawn(async move {
match rx_init.recv_async().await {
Ok(payload) => {
tx_init.send(payload).expect("tx_init to work properly");
}
Err(err) => {
error!("tx_init closed in reader: {err:?}");
}
}
});
let handle = tokio::spawn(async move {
let stream = resp.bytes_stream();
pin_mut!(stream);
debug!("stream pinned");
let mut data = stream.next().await;
if let Some(Err(err)) = &data {
let msg = format!("S3 bucket error: {err}");
tx.send_async(Err(CryptrError::S3(msg.clone()))).await?;
return Err(CryptrError::S3(msg));
}
let mut header = None;
let mut chunk_size = 0;
let mut buf = BytesMut::with_capacity(chunk_size);
let mut total = 0;
loop {
let bytes = data.unwrap().unwrap();
total += bytes.len();
buf.extend(bytes);
debug!("buf len: {:?}", buf.len());
let _ = tx_progress.send(total);
if header.is_none() {
let (enc_header, nonce, payload_offset) =
match EncValueHeader::try_extract_with_nonce(buf.as_ref()) {
Ok(d) => d,
Err(err) => {
let msg = format!(
"Error extracting encryption header from first chunk: {err:?}"
);
tx.send_async(Err(CryptrError::S3(msg.clone()))).await?;
return Err(CryptrError::S3(msg));
}
};
debug!(
"Extracted header data from first chunk: {enc_header:?} with \
payload_offset: {payload_offset}"
);
tx_init_internal
.send((enc_header.clone(), nonce))
.expect("tx_init_internal to be only called once");
let _header_bytes = buf.split_to(payload_offset as usize);
chunk_size =
enc_header.chunk_size.value_bytes_with_mac(&enc_header.alg) as usize;
header = Some(enc_header);
}
data = stream.next().await;
let is_stream_empty = match &data {
None => true,
Some(res) => {
if res.is_err() {
debug!("stream rest in loop error: {res:?}");
tx.send_async(Err(CryptrError::S3(format!("{res:?}"))))
.await?;
return Err(CryptrError::S3(format!("{res:?}")));
}
false
}
};
while buf.len() > chunk_size {
let bytes = buf.split_to(chunk_size);
debug!(
"sending non-last chunk with len: {} with data left in buf: {}",
bytes.len(),
buf.len()
);
tx.send_async(Ok((LastStreamElement::No, StreamChunk(bytes.to_vec()))))
.await?;
}
if is_stream_empty {
debug!("sending last element with len: {}", buf.len());
tx.send_async(Ok((LastStreamElement::Yes, StreamChunk(buf.to_vec()))))
.await?;
break;
}
}
debug!("Read {total} bytes");
Ok(())
});
Ok(handle)
}
}
impl S3Reader<'_> {
    /// Creates the progress channel and, if `print_progress` is `true`, spawns a
    /// background task that prints the download progress of `object` to stdout
    /// every 5 seconds.
    ///
    /// The returned sender is fed with the total number of bytes read so far.
    /// The printer task stops once the reported progress reaches
    /// `content_length`, or once the sender has been dropped (reader finished or
    /// failed), so it can never outlive the download.
    ///
    /// When `print_progress` is `false`, no task is spawned and the receiver is
    /// dropped right away, so sends on the returned sender simply fail.
    async fn spawn_progress(
        print_progress: bool,
        object: &str,
        content_length: usize,
    ) -> watch::Sender<usize> {
        let (tx_progress, mut rx_progress) = sync::watch::channel(0);

        if print_progress {
            let object = object.to_string();

            tokio::spawn(async move {
                // Pick a display unit that keeps the printed numbers readable.
                let (div, unit) = if content_length > 1024 * 1024 * 10 {
                    ((1024 * 1024) as f64, "MiB")
                } else if content_length > 1024 * 10 {
                    (1024f64, "KiB")
                } else {
                    (1f64, "Bytes")
                };
                let target = content_length as f64 / div;

                let start = Instant::now();
                let mut interval = time::interval(Duration::from_secs(5));
                // The first tick completes immediately, skip it.
                interval.tick().await;

                loop {
                    interval.tick().await;

                    // Mark the value as seen first so that `has_changed` below
                    // reliably reports a closed channel.
                    let progress = *rx_progress.borrow_and_update() as f64 / div;
                    let sender_dropped = rx_progress.has_changed().is_err();

                    let rate = progress / start.elapsed().as_secs_f64();
                    println!(
                        "S3Reader ({object}) {progress:.02} / {target:.02} {unit} \
                        -> {rate:.02} {unit}/s"
                    );

                    if progress >= target || sender_dropped {
                        break;
                    }
                }
            });
        }

        tx_progress
    }
}