use crate::crypto::Checksum as _;
use crate::crypto::Sha256;
use crate::error::StdError;
use crate::stream::{ByteStream, DynByteStream, RemainingLength};
use bytes::Bytes;
use futures::Stream;
use std::fmt;
use std::pin::Pin;
use std::task::{Context, Poll};
use thiserror::Error;
/// Stream adapter that forwards chunks from an inner stream while checking
/// them against a declared byte length and an expected SHA-256 digest.
pub struct UploadStream<S> {
/// The wrapped stream that produces the payload chunks.
inner: S,
/// Running hasher fed with every forwarded chunk; taken (left as `None`)
/// once the digest has been finalized.
hasher: Option<Sha256>,
/// Raw 32-byte digest decoded from the hex string given to `new`.
expected_sha256: [u8; 32],
/// Number of payload bytes still expected; starts at the declared length
/// and is decremented by the size of each accepted chunk.
remaining_length: usize,
/// Current phase of the verification state machine driven by `poll_next`.
state: State,
}
/// Debug output shows only the progress fields (`remaining_length` and
/// `state`); the remaining fields are elided and the struct is rendered as
/// non-exhaustive. No `Debug` bound is required on `S`.
impl<S> fmt::Debug for UploadStream<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            remaining_length,
            state,
            ..
        } = self;

        let mut builder = f.debug_struct("UploadStream");
        builder.field("remaining_length", remaining_length);
        builder.field("state", state);
        builder.finish_non_exhaustive()
    }
}
/// Phases of the `UploadStream` state machine.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
/// Payload bytes are still expected; incoming chunks are hashed and forwarded.
Reading,
/// The declared length has been received and the digest matched; the inner
/// stream is now expected to end without yielding any further non-empty chunk.
AwaitingEof,
/// The inner stream ended after a complete, verified payload; further polls
/// yield `None`.
Finished,
/// An error has been reported; further polls yield `None`.
Failed,
}
/// Errors produced while constructing or polling an `UploadStream`.
#[derive(Debug, Error)]
pub enum UploadStreamError {
/// The wrapped stream itself yielded an error, carried here unchanged.
#[error("UploadStreamError: Underlying: {0}")]
Underlying(StdError),
/// The SHA-256 digest of the received bytes differs from the expected digest.
#[error("UploadStreamError: Sha256Mismatch")]
Sha256Mismatch,
/// More bytes arrived than were declared: either a chunk exceeded the
/// remaining length, or a non-empty chunk arrived after the payload was complete.
#[error("UploadStreamError: LengthMismatch")]
LengthMismatch,
/// The wrapped stream ended before the declared number of bytes was received.
#[error("UploadStreamError: Incomplete")]
Incomplete,
/// The expected checksum string was not 64 characters long or could not be
/// hex-decoded.
#[error("UploadStreamError: InvalidChecksum")]
InvalidChecksum,
}
impl<S> UploadStream<S> {
    /// Wraps `inner` so that the bytes it yields are checked against a declared
    /// `length` and an expected SHA-256 digest.
    ///
    /// `hex_sha256` is the expected digest as a 64-character hex string.
    ///
    /// # Errors
    ///
    /// Returns [`UploadStreamError::InvalidChecksum`] when `hex_sha256` is not
    /// exactly 64 characters long or cannot be hex-decoded.
    pub fn new(inner: S, length: usize, hex_sha256: &str) -> Result<Self, UploadStreamError> {
        let expected_sha256 = decode_sha256_hex(hex_sha256)?;
        Ok(Self {
            inner,
            hasher: Some(Sha256::new()),
            expected_sha256,
            remaining_length: length,
            state: State::Reading,
        })
    }

    /// Converts this stream into a [`DynByteStream`] via `crate::stream::into_dyn`.
    pub fn into_byte_stream(self) -> DynByteStream
    where
        Self: Sized + ByteStream + Stream<Item = Result<Bytes, UploadStreamError>> + Send + Sync + Unpin + 'static,
    {
        crate::stream::into_dyn(self)
    }

    /// Verifies the digest once every declared byte has been received.
    ///
    /// Returns `Ok(false)` without touching any state while bytes are still
    /// outstanding. Otherwise the hash is finalized: on success the state
    /// becomes [`State::AwaitingEof`] and `Ok(true)` is returned; on a digest
    /// mismatch the state becomes [`State::Failed`] and the error is returned.
    fn finalize_if_complete(&mut self) -> Result<bool, UploadStreamError> {
        if self.remaining_length != 0 {
            return Ok(false);
        }

        let outcome = self.finalize_hash();
        self.state = if outcome.is_ok() {
            State::AwaitingEof
        } else {
            State::Failed
        };
        outcome.map(|()| true)
    }

    /// Consumes the running hasher and compares its digest with the expected one.
    ///
    /// The hasher is taken out of `self`, so the comparison happens at most
    /// once; calling this again after the hasher is gone is a no-op that
    /// returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`UploadStreamError::Sha256Mismatch`] when the digests differ.
    fn finalize_hash(&mut self) -> Result<(), UploadStreamError> {
        // Taking and testing in one step avoids a separate `is_none()` check
        // followed by an `unwrap()` that could only ever panic on a logic error.
        let Some(hasher) = self.hasher.take() else {
            return Ok(());
        };

        if hasher.finalize() == self.expected_sha256 {
            Ok(())
        } else {
            Err(UploadStreamError::Sha256Mismatch)
        }
    }
}
/// Forwards chunks from the wrapped stream while enforcing the declared
/// length and the expected SHA-256 digest.
///
/// Empty chunks are skipped. Once an error has been yielded, or the stream has
/// finished cleanly, every later poll yields `None`.
impl<S> Stream for UploadStream<S>
where
    S: Stream<Item = Result<Bytes, StdError>> + Unpin,
{
    type Item = Result<Bytes, UploadStreamError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // `true` when the full payload has already been received and
            // verified, so the only acceptable event left is end-of-stream.
            let awaiting_eof = match self.state {
                State::Finished | State::Failed => return Poll::Ready(None),
                State::AwaitingEof => true,
                // Covers the zero-length case: the digest is checked before
                // the inner stream is polled at all.
                State::Reading => match self.finalize_if_complete() {
                    Ok(complete) => complete,
                    Err(err) => return Poll::Ready(Some(Err(err))),
                },
            };

            let next = match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Ready(next) => next,
                Poll::Pending => return Poll::Pending,
            };

            let chunk = match next {
                Some(Ok(chunk)) => chunk,
                Some(Err(err)) => {
                    self.state = State::Failed;
                    return Poll::Ready(Some(Err(UploadStreamError::Underlying(err))));
                }
                None if awaiting_eof => {
                    self.state = State::Finished;
                    return Poll::Ready(None);
                }
                None => {
                    // The source ended while bytes were still outstanding.
                    self.state = State::Failed;
                    return Poll::Ready(Some(Err(UploadStreamError::Incomplete)));
                }
            };

            if chunk.is_empty() {
                continue;
            }

            let chunk_len = chunk.len();
            if awaiting_eof || chunk_len > self.remaining_length {
                self.state = State::Failed;
                return Poll::Ready(Some(Err(UploadStreamError::LengthMismatch)));
            }

            if let Some(hasher) = self.hasher.as_mut() {
                hasher.update(chunk.as_ref());
            }
            self.remaining_length -= chunk_len;

            // When this chunk completes the payload, the digest is verified
            // now; on mismatch the error is yielded instead of the chunk.
            let verdict = self.finalize_if_complete();
            return Poll::Ready(Some(verdict.map(|_| chunk)));
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, None)
    }
}
impl<S> ByteStream for UploadStream<S>
where
    S: Stream<Item = Result<Bytes, StdError>> + Unpin,
{
    /// Reports the number of payload bytes still expected, passed to
    /// `RemainingLength::new_exact`.
    fn remaining_length(&self) -> RemainingLength {
        let outstanding = self.remaining_length;
        RemainingLength::new_exact(outstanding)
    }
}
/// Decodes a hex-encoded SHA-256 digest into its 32 raw bytes.
///
/// # Errors
///
/// Returns [`UploadStreamError::InvalidChecksum`] when the input is not exactly
/// 64 bytes long or when `hex_simd::decode` rejects it.
fn decode_sha256_hex(expected_sha256: &str) -> Result<[u8; 32], UploadStreamError> {
    let hex = expected_sha256.as_bytes();

    // 32 digest bytes, two hex digits each.
    if hex.len() != 64 {
        return Err(UploadStreamError::InvalidChecksum);
    }

    let mut digest = [0_u8; 32];
    let decoded = hex_simd::decode(hex, hex_simd::Out::from_slice(&mut digest)).is_ok();
    if decoded {
        Ok(digest)
    } else {
        Err(UploadStreamError::InvalidChecksum)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::crypto::hex_sha256_string;
    use futures::StreamExt as _;
    use std::io;

    /// Item type yielded by the in-memory source streams used in these tests.
    type Chunk = Result<Bytes, StdError>;

    /// In-memory source stream backed by a `Vec` of chunks.
    type MemoryStream = futures::stream::Iter<std::vec::IntoIter<Chunk>>;

    /// Wraps a static byte slice as a successful chunk.
    #[allow(clippy::unnecessary_wraps)]
    fn ok_bytes(data: &'static [u8]) -> Result<Bytes, StdError> {
        Ok(Bytes::from_static(data))
    }

    /// Builds an `UploadStream` over the given chunks, panicking if
    /// construction fails.
    fn upload_over(chunks: Vec<Chunk>, length: usize, checksum: &str) -> UploadStream<MemoryStream> {
        UploadStream::new(futures::stream::iter(chunks), length, checksum).unwrap()
    }

    /// A single chunk with matching length and digest is forwarded, then the
    /// stream ends.
    #[tokio::test]
    async fn single_chunk_success() {
        let payload = b"hello world";
        let digest = hex_sha256_string(payload);
        let mut upload = upload_over(vec![ok_bytes(payload)], payload.len(), &digest);

        let received = upload.next().await.unwrap().unwrap();
        assert_eq!(received.as_ref(), payload);
        assert!(upload.next().await.is_none());
    }

    /// A digest computed over different data is reported as `Sha256Mismatch`.
    #[tokio::test]
    async fn sha256_mismatch() {
        let payload = b"hello";
        let digest = hex_sha256_string(b"world");
        let mut upload = upload_over(vec![ok_bytes(payload)], payload.len(), &digest);

        let failure = upload.next().await.unwrap().unwrap_err();
        assert!(matches!(failure, UploadStreamError::Sha256Mismatch));
    }

    /// A chunk that overshoots the declared length is reported as
    /// `LengthMismatch`.
    #[tokio::test]
    async fn length_mismatch_extra_bytes() {
        let payload = b"abcdef";
        let digest = hex_sha256_string(payload);
        let mut upload = upload_over(vec![ok_bytes(b"abc"), ok_bytes(b"def")], 5, &digest);

        let head = upload.next().await.unwrap().unwrap();
        assert_eq!(head.as_ref(), b"abc");

        let failure = upload.next().await.unwrap().unwrap_err();
        assert!(matches!(failure, UploadStreamError::LengthMismatch));
    }

    /// A source that ends before the declared length is reported as
    /// `Incomplete`.
    #[tokio::test]
    async fn incomplete_stream() {
        let payload = b"abcdef";
        let digest = hex_sha256_string(payload);
        let mut upload = upload_over(vec![ok_bytes(b"abc")], payload.len(), &digest);

        let head = upload.next().await.unwrap().unwrap();
        assert_eq!(head.as_ref(), b"abc");

        let failure = upload.next().await.unwrap().unwrap_err();
        assert!(matches!(failure, UploadStreamError::Incomplete));
    }

    /// An empty payload with the digest of the empty input ends immediately.
    #[tokio::test]
    async fn zero_length_success() {
        let digest = hex_sha256_string(b"");
        let mut upload = upload_over(Vec::new(), 0, &digest);

        assert!(upload.next().await.is_none());
    }

    /// A checksum string of the wrong shape is rejected at construction time.
    #[tokio::test]
    async fn invalid_checksum_hex() {
        let source = futures::stream::iter(Vec::<Chunk>::new());

        let failure = UploadStream::new(source, 0, "zz").unwrap_err();
        assert!(matches!(failure, UploadStreamError::InvalidChecksum));
    }

    /// Data arriving after the full payload was received is reported as
    /// `LengthMismatch`.
    #[tokio::test]
    async fn extra_payload_after_completion() {
        let payload = b"abc";
        let digest = hex_sha256_string(payload);
        let chunks = vec![ok_bytes(payload), ok_bytes(b"extra")];
        let mut upload = upload_over(chunks, payload.len(), &digest);

        let head = upload.next().await.unwrap().unwrap();
        assert_eq!(head.as_ref(), payload);

        let failure = upload.next().await.unwrap().unwrap_err();
        assert!(matches!(failure, UploadStreamError::LengthMismatch));
    }

    /// An error from the source stream surfaces as `Underlying`.
    #[tokio::test]
    async fn propagate_underlying_error() {
        let digest = hex_sha256_string(b"");
        let broken: Chunk = Err(Box::new(io::Error::other("boom")));
        let mut upload = upload_over(vec![broken], 0, &digest);

        let failure = upload.next().await.unwrap().unwrap_err();
        assert!(matches!(failure, UploadStreamError::Underlying(_)));
    }
}