1use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use bytes::{Bytes, BytesMut};
11use futures::{Stream, StreamExt};
12use s3s::StdError;
13use s3s::dto::StreamingBlob;
14use s3s::stream::{ByteStream, RemainingLength};
15
16pub async fn collect_blob(blob: StreamingBlob, max_bytes: usize) -> Result<Bytes, BlobError> {
18 let hint = blob.remaining_length().exact().unwrap_or(0).min(max_bytes);
19 let mut buf = BytesMut::with_capacity(hint);
20 let mut stream = blob;
21 while let Some(chunk) = stream.next().await {
22 let chunk = chunk.map_err(|e| BlobError::Read(format!("{e}")))?;
23 if buf.len().saturating_add(chunk.len()) > max_bytes {
24 return Err(BlobError::Oversized {
25 limit: max_bytes,
26 seen_at_least: buf.len() + chunk.len(),
27 });
28 }
29 buf.extend_from_slice(&chunk);
30 }
31 Ok(buf.freeze())
32}
33
34pub fn bytes_to_blob(bytes: Bytes) -> StreamingBlob {
41 StreamingBlob::new(SingleChunkBlob(Some(bytes)))
42}
43
44struct SingleChunkBlob(Option<Bytes>);
48
49impl Stream for SingleChunkBlob {
50 type Item = Result<Bytes, StdError>;
51 fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
52 Poll::Ready(self.get_mut().0.take().map(Ok))
53 }
54 fn size_hint(&self) -> (usize, Option<usize>) {
55 match &self.0 {
56 Some(_) => (1, Some(1)),
57 None => (0, Some(0)),
58 }
59 }
60}
61
62impl ByteStream for SingleChunkBlob {
63 fn remaining_length(&self) -> RemainingLength {
64 match &self.0 {
65 Some(b) => RemainingLength::new_exact(b.len()),
66 None => RemainingLength::new_exact(0),
67 }
68 }
69}
70
/// Errors produced while buffering or sampling a streaming body.
#[derive(Debug, thiserror::Error)]
pub enum BlobError {
    /// Accepting the next chunk would have pushed the buffered body past the caller's byte
    /// limit, so reading stopped at that point.
    ///
    /// `limit` is the configured maximum. `seen_at_least` is the size the buffer would have
    /// reached had the offending chunk been accepted; the real body may be larger still.
    #[error("body exceeded configured limit ({limit} bytes); saw at least {seen_at_least}")]
    Oversized { limit: usize, seen_at_least: usize },
    /// The underlying stream yielded an error. Holds that error's rendered (`Display`) text,
    /// so the original error value and its source chain are not preserved here.
    #[error("error reading streaming body: {0}")]
    Read(String),
}
78
79pub async fn peek_sample(
83 mut blob: StreamingBlob,
84 sample_bytes: usize,
85) -> Result<(Bytes, StreamingBlob), BlobError> {
86 let mut sample = BytesMut::with_capacity(sample_bytes);
87 let mut leftover: Option<Bytes> = None;
88 while sample.len() < sample_bytes {
89 match blob.next().await {
90 Some(Ok(chunk)) => {
91 let remaining = sample_bytes.saturating_sub(sample.len());
92 if chunk.len() <= remaining {
93 sample.extend_from_slice(&chunk);
94 } else {
95 sample.extend_from_slice(&chunk[..remaining]);
96 leftover = Some(chunk.slice(remaining..));
97 break;
98 }
99 }
100 Some(Err(e)) => return Err(BlobError::Read(format!("{e}"))),
101 None => break,
102 }
103 }
104 let sample_bytes = sample.freeze();
105 let rest = chain_leftover_with_blob(leftover, blob);
106 Ok((sample_bytes, rest))
107}
108
109pub fn chain_sample_with_rest(sample: Bytes, rest: StreamingBlob) -> StreamingBlob {
112 let head = futures::stream::once(async move { Ok::<_, std::io::Error>(sample) });
113 let tail = rest.map(|r| r.map_err(|e| std::io::Error::other(e.to_string())));
114 StreamingBlob::wrap(head.chain(tail))
115}
116
117fn chain_leftover_with_blob(leftover: Option<Bytes>, rest: StreamingBlob) -> StreamingBlob {
118 match leftover {
119 Some(b) => chain_sample_with_rest(b, rest),
120 None => rest,
121 }
122}
123
124pub async fn collect_with_sample(
126 sample: Bytes,
127 rest: StreamingBlob,
128 max_bytes: usize,
129) -> Result<Bytes, BlobError> {
130 if sample.len() > max_bytes {
131 return Err(BlobError::Oversized {
132 limit: max_bytes,
133 seen_at_least: sample.len(),
134 });
135 }
136 let mut buf = BytesMut::with_capacity(sample.len() + 4096);
137 buf.extend_from_slice(&sample);
138 let mut stream = rest;
139 while let Some(chunk) = stream.next().await {
140 let chunk = chunk.map_err(|e| BlobError::Read(format!("{e}")))?;
141 if buf.len().saturating_add(chunk.len()) > max_bytes {
142 return Err(BlobError::Oversized {
143 limit: max_bytes,
144 seen_at_least: buf.len() + chunk.len(),
145 });
146 }
147 buf.extend_from_slice(&chunk);
148 }
149 Ok(buf.freeze())
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn collect_roundtrip() {
        let original = Bytes::from_static(b"hello squished s3");
        let blob = bytes_to_blob(original.clone());
        let collected = collect_blob(blob, 1024).await.unwrap();
        assert_eq!(collected, original);
    }

    #[tokio::test]
    async fn collect_rejects_oversized() {
        let big = Bytes::from(vec![0u8; 2048]);
        let blob = bytes_to_blob(big);
        let err = collect_blob(blob, 1024).await.unwrap_err();
        assert!(matches!(err, BlobError::Oversized { .. }));
    }

    /// The limit is inclusive: a body of exactly `max_bytes` must be accepted.
    #[tokio::test]
    async fn collect_accepts_body_exactly_at_limit() {
        let exact = Bytes::from(vec![7u8; 1024]);
        let collected = collect_blob(bytes_to_blob(exact.clone()), 1024).await.unwrap();
        assert_eq!(collected, exact);
    }

    /// A chunk straddling the sample boundary is split, and sample + rest rebuilds the body.
    #[tokio::test]
    async fn peek_sample_splits_straddling_chunk_and_rejoins() {
        let original = Bytes::from_static(b"0123456789");
        let (sample, rest) = peek_sample(bytes_to_blob(original.clone()), 4)
            .await
            .unwrap();
        assert_eq!(&sample[..], b"0123");
        let whole = collect_with_sample(sample, rest, 1024).await.unwrap();
        assert_eq!(whole, original);
    }

    /// A body shorter than the requested sample is returned whole, with an empty remainder.
    #[tokio::test]
    async fn peek_sample_returns_short_sample_for_short_body() {
        let original = Bytes::from_static(b"abc");
        let (sample, rest) = peek_sample(bytes_to_blob(original.clone()), 16)
            .await
            .unwrap();
        assert_eq!(sample, original);
        let tail = collect_blob(rest, 1024).await.unwrap();
        assert!(tail.is_empty());
    }

    /// Sampling zero bytes must not consume anything from the body.
    #[tokio::test]
    async fn peek_sample_of_zero_bytes_leaves_body_untouched() {
        let original = Bytes::from_static(b"untouched");
        let (sample, rest) = peek_sample(bytes_to_blob(original.clone()), 0)
            .await
            .unwrap();
        assert!(sample.is_empty());
        assert_eq!(collect_blob(rest, 1024).await.unwrap(), original);
    }

    #[tokio::test]
    async fn collect_with_sample_rejects_oversized_sample() {
        let err = collect_with_sample(Bytes::from_static(b"abcdef"), bytes_to_blob(Bytes::new()), 3)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            BlobError::Oversized {
                limit: 3,
                seen_at_least: 6
            }
        ));
    }

    #[tokio::test]
    async fn collect_with_sample_rejects_oversized_rest() {
        let rest = bytes_to_blob(Bytes::from_static(b"cdef"));
        let err = collect_with_sample(Bytes::from_static(b"ab"), rest, 4)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            BlobError::Oversized {
                limit: 4,
                seen_at_least: 6
            }
        ));
    }
}
171}