use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::{Bytes, BytesMut};
use futures::{Stream, StreamExt};
use s3s::StdError;
use s3s::dto::StreamingBlob;
use s3s::stream::{ByteStream, RemainingLength};
/// Drains `blob` into one contiguous [`Bytes`] buffer, refusing to hold more than `max_bytes`.
///
/// When the stream advertises an exact remaining length, that length is used to size the
/// buffer up front, clamped to `max_bytes` so a large advertised length cannot force a large
/// allocation. Otherwise the buffer starts empty and grows on demand.
///
/// # Errors
///
/// - [`BlobError::Read`] if the underlying stream yields an error.
/// - [`BlobError::Oversized`] as soon as the accumulated length would exceed `max_bytes`; the
///   rest of the body is not read.
pub async fn collect_blob(blob: StreamingBlob, max_bytes: usize) -> Result<Bytes, BlobError> {
    let initial_capacity = match blob.remaining_length().exact() {
        Some(advertised) => advertised.min(max_bytes),
        None => 0,
    };
    let mut collected = BytesMut::with_capacity(initial_capacity);
    let mut body = blob;

    while let Some(next) = body.next().await {
        let piece = match next {
            Ok(piece) => piece,
            Err(err) => return Err(BlobError::Read(format!("{err}"))),
        };

        // Check before copying so an oversized chunk is never buffered.
        let over_limit = collected.len().saturating_add(piece.len()) > max_bytes;
        if over_limit {
            return Err(BlobError::Oversized {
                limit: max_bytes,
                seen_at_least: collected.len() + piece.len(),
            });
        }
        collected.extend_from_slice(&piece);
    }

    Ok(collected.freeze())
}
/// Wraps an in-memory buffer as a [`StreamingBlob`] that yields it as a single chunk.
///
/// The inner stream reports the buffer's exact length through `ByteStream::remaining_length`.
pub fn bytes_to_blob(bytes: Bytes) -> StreamingBlob {
    let single_chunk = SingleChunkBlob(Some(bytes));
    StreamingBlob::new(single_chunk)
}
struct SingleChunkBlob(Option<Bytes>);
impl Stream for SingleChunkBlob {
type Item = Result<Bytes, StdError>;
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(self.get_mut().0.take().map(Ok))
}
fn size_hint(&self) -> (usize, Option<usize>) {
match &self.0 {
Some(_) => (1, Some(1)),
None => (0, Some(0)),
}
}
}
impl ByteStream for SingleChunkBlob {
fn remaining_length(&self) -> RemainingLength {
match &self.0 {
Some(b) => RemainingLength::new_exact(b.len()),
None => RemainingLength::new_exact(0),
}
}
}
/// Errors produced while buffering or sampling a [`StreamingBlob`].
#[derive(Debug, thiserror::Error)]
pub enum BlobError {
    /// The body grew past the caller's byte limit.
    ///
    /// `limit` is the configured ceiling. `seen_at_least` is the byte count observed at the
    /// moment the limit was crossed. Reading stops there, so this is a lower bound on the full
    /// body size.
    #[error("body exceeded configured limit ({limit} bytes); saw at least {seen_at_least}")]
    Oversized { limit: usize, seen_at_least: usize },
    /// The underlying stream yielded an error; holds that error's `Display` text.
    #[error("error reading streaming body: {0}")]
    Read(String),
}
/// Reads up to `sample_bytes` bytes from the front of `blob` without losing the remainder.
///
/// Returns `(sample, rest)`:
/// - `sample` holds the first `min(sample_bytes, body length)` bytes.
/// - `rest` yields every byte after them.
///
/// If a chunk straddles the sample boundary, its unread tail is re-attached to the front of
/// `rest`. Passing `sample_bytes == 0` returns an empty sample and `blob` unchanged.
///
/// # Errors
///
/// [`BlobError::Read`] if the stream errors before the sample is complete. The bytes read so
/// far are dropped together with `blob`.
pub async fn peek_sample(
    mut blob: StreamingBlob,
    sample_bytes: usize,
) -> Result<(Bytes, StreamingBlob), BlobError> {
    // Ceiling on the up-front reservation. The buffer can still grow beyond this while reading;
    // it just isn't reserved speculatively.
    const MAX_PREALLOC: usize = 64 * 1024;

    // Reserving `sample_bytes` outright panics with "capacity overflow" for very large requests
    // (e.g. `usize::MAX` meaning "peek everything") and wastes memory on short bodies. Clamp it
    // by the body's exact length when known, and always by MAX_PREALLOC.
    let body_len = blob.remaining_length().exact().unwrap_or(usize::MAX);
    let capacity = sample_bytes.min(body_len).min(MAX_PREALLOC);

    let mut sample = BytesMut::with_capacity(capacity);
    let mut leftover: Option<Bytes> = None;

    while sample.len() < sample_bytes {
        match blob.next().await {
            Some(Ok(chunk)) => {
                let remaining = sample_bytes.saturating_sub(sample.len());
                if chunk.len() <= remaining {
                    sample.extend_from_slice(&chunk);
                } else {
                    // Split the straddling chunk. The head completes the sample; the tail is a
                    // zero-copy slice that is put back in front of `rest`.
                    sample.extend_from_slice(&chunk[..remaining]);
                    leftover = Some(chunk.slice(remaining..));
                    break;
                }
            }
            Some(Err(e)) => return Err(BlobError::Read(format!("{e}"))),
            None => break,
        }
    }

    let rest = chain_leftover_with_blob(leftover, blob);
    Ok((sample.freeze(), rest))
}
pub fn chain_sample_with_rest(sample: Bytes, rest: StreamingBlob) -> StreamingBlob {
let head = futures::stream::once(async move { Ok::<_, std::io::Error>(sample) });
let tail = rest.map(|r| r.map_err(|e| std::io::Error::other(e.to_string())));
StreamingBlob::wrap(head.chain(tail))
}
/// Re-attaches bytes that were pulled off `rest` but not consumed, so nothing is lost.
///
/// With no leftover, `rest` is returned as-is.
fn chain_leftover_with_blob(leftover: Option<Bytes>, rest: StreamingBlob) -> StreamingBlob {
    let Some(unconsumed) = leftover else {
        return rest;
    };
    chain_sample_with_rest(unconsumed, rest)
}
/// Reassembles a body previously split by [`peek_sample`]: `sample` followed by all of `rest`.
///
/// Enforces the same `max_bytes` ceiling as [`collect_blob`], with the sample counting toward
/// it.
///
/// # Errors
///
/// - [`BlobError::Oversized`] if `sample` alone, or `sample` plus any prefix of `rest`, would
///   exceed `max_bytes`. The rest of the body is not read.
/// - [`BlobError::Read`] if `rest` yields an error.
pub async fn collect_with_sample(
    sample: Bytes,
    rest: StreamingBlob,
    max_bytes: usize,
) -> Result<Bytes, BlobError> {
    if sample.len() > max_bytes {
        return Err(BlobError::Oversized {
            limit: max_bytes,
            seen_at_least: sample.len(),
        });
    }

    // Size the buffer consistently with `collect_blob`:
    // - use `rest`'s exact remaining length when it reports one, else a small default;
    // - never reserve past `max_bytes`.
    // Since `sample.len() <= max_bytes` here, the capacity always covers the sample.
    let rest_hint = rest.remaining_length().exact().unwrap_or(4096);
    let capacity = sample.len().saturating_add(rest_hint).min(max_bytes);
    let mut buf = BytesMut::with_capacity(capacity);
    buf.extend_from_slice(&sample);

    let mut stream = rest;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| BlobError::Read(format!("{e}")))?;
        let total = buf.len().saturating_add(chunk.len());
        if total > max_bytes {
            return Err(BlobError::Oversized {
                limit: max_bytes,
                seen_at_least: total,
            });
        }
        buf.extend_from_slice(&chunk);
    }
    Ok(buf.freeze())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn collect_roundtrip() {
        let original = Bytes::from_static(b"hello squished s3");
        let blob = bytes_to_blob(original.clone());
        let collected = collect_blob(blob, 1024).await.unwrap();
        assert_eq!(collected, original);
    }

    #[tokio::test]
    async fn collect_rejects_oversized() {
        let big = Bytes::from(vec![0u8; 2048]);
        let blob = bytes_to_blob(big);
        let err = collect_blob(blob, 1024).await.unwrap_err();
        assert!(matches!(
            err,
            BlobError::Oversized {
                limit: 1024,
                seen_at_least: 2048
            }
        ));
    }

    /// A body of exactly `max_bytes` is allowed; only strictly larger bodies are rejected.
    #[tokio::test]
    async fn collect_accepts_body_exactly_at_limit() {
        let exact = Bytes::from(vec![7u8; 1024]);
        let collected = collect_blob(bytes_to_blob(exact.clone()), 1024).await.unwrap();
        assert_eq!(collected, exact);
    }

    #[tokio::test]
    async fn collect_empty_body() {
        let collected = collect_blob(bytes_to_blob(Bytes::new()), 0).await.unwrap();
        assert!(collected.is_empty());
    }

    /// The sample boundary falls inside the single chunk, so the leftover must be re-chained.
    #[tokio::test]
    async fn peek_then_collect_roundtrip_with_split_chunk() {
        let original = Bytes::from_static(b"hello squished s3");
        let (sample, rest) = peek_sample(bytes_to_blob(original.clone()), 5).await.unwrap();
        assert_eq!(sample, Bytes::from_static(b"hello"));
        let collected = collect_with_sample(sample, rest, 1024).await.unwrap();
        assert_eq!(collected, original);
    }

    #[tokio::test]
    async fn peek_larger_than_body_returns_whole_body() {
        let original = Bytes::from_static(b"tiny");
        let (sample, rest) = peek_sample(bytes_to_blob(original.clone()), 1024).await.unwrap();
        assert_eq!(sample, original);
        let collected = collect_with_sample(sample, rest, 1024).await.unwrap();
        assert_eq!(collected, original);
    }

    #[tokio::test]
    async fn collect_with_sample_rejects_oversized_sample() {
        let sample = Bytes::from(vec![0u8; 10]);
        let err = collect_with_sample(sample, bytes_to_blob(Bytes::new()), 5)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            BlobError::Oversized {
                limit: 5,
                seen_at_least: 10
            }
        ));
    }

    /// The sample's bytes count toward the limit together with the rest of the body.
    #[tokio::test]
    async fn collect_with_sample_counts_sample_toward_limit() {
        let body = Bytes::from_static(b"hello squished s3");
        let (sample, rest) = peek_sample(bytes_to_blob(body), 5).await.unwrap();
        let err = collect_with_sample(sample, rest, 10).await.unwrap_err();
        assert!(matches!(
            err,
            BlobError::Oversized {
                limit: 10,
                seen_at_least: 17
            }
        ));
    }
}