1use std::io;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25
26use async_compression::Level;
27use async_compression::tokio::bufread::ZstdDecoder;
28use async_compression::tokio::write::ZstdEncoder;
29use bytes::Bytes;
30use futures::{Stream, StreamExt};
31use s3s::StdError;
32use s3s::dto::StreamingBlob;
33use s3s::stream::{ByteStream, RemainingLength};
34use s4_codec::{ChunkManifest, CodecError, CodecKind};
35use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
36use tokio_util::io::{ReaderStream, StreamReader};
37
38pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync {
45 let mapped = blob.map(|chunk| chunk.map_err(|e| io::Error::other(e.to_string())));
46 StreamReader::new(mapped)
47}
48
49pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
51 reader: R,
52) -> StreamingBlob {
53 let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
54 StreamingBlob::new(StreamWrapper { inner: stream })
55}
56
// Pinned newtype around a chunk stream. It exists so we can implement both
// `Stream` and s3s's `ByteStream` on one type; `#[pin]` lets `poll_next`
// delegate to the inner stream without requiring `S: Unpin`.
pin_project_lite::pin_project! {
    struct StreamWrapper<S> { #[pin] inner: S }
}
62
63impl<S> Stream for StreamWrapper<S>
64where
65 S: Stream<Item = Result<Bytes, StdError>>,
66{
67 type Item = Result<Bytes, StdError>;
68 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69 self.project().inner.poll_next(cx)
70 }
71 fn size_hint(&self) -> (usize, Option<usize>) {
72 self.inner.size_hint()
73 }
74}
75
// s3s requires blob bodies to implement `ByteStream`. The wrapped streams here
// come from decoders/readers whose total output length cannot be known up
// front, so the remaining length is always reported as unknown.
impl<S> ByteStream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
{
    fn remaining_length(&self) -> RemainingLength {
        RemainingLength::unknown()
    }
}
85
86pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
91 let read = blob_to_async_read(body);
92 let decoder = ZstdDecoder::new(BufReader::new(read));
93 async_read_to_blob(decoder)
94}
95
/// Returns `true` if `codec` can be decompressed in streaming fashion
/// (without buffering the whole object): currently passthrough and CPU zstd.
pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
    matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
}
100
/// Returns `true` if `codec` has a streaming compression path in this module
/// (`streaming_passthrough` / `streaming_compress_cpu_zstd`).
pub fn supports_streaming_compress(codec: CodecKind) -> bool {
    matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
}
104
105pub async fn streaming_compress_cpu_zstd(
115 body: StreamingBlob,
116 level: i32,
117) -> Result<(Bytes, ChunkManifest), CodecError> {
118 let mut read = blob_to_async_read(body);
119 let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
120 let mut crc: u32 = 0;
121 let mut total_in: u64 = 0;
122 let mut in_buf = vec![0u8; 64 * 1024];
123
124 {
125 let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
126 loop {
127 let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
128 if n == 0 {
129 break;
130 }
131 crc = crc32c::crc32c_append(crc, &in_buf[..n]);
132 total_in += n as u64;
133 encoder
134 .write_all(&in_buf[..n])
135 .await
136 .map_err(CodecError::Io)?;
137 }
138 encoder.shutdown().await.map_err(CodecError::Io)?;
139 }
140
141 let compressed_len = compressed_buf.len() as u64;
142 Ok((
143 Bytes::from(compressed_buf),
144 ChunkManifest {
145 codec: CodecKind::CpuZstd,
146 original_size: total_in,
147 compressed_size: compressed_len,
148 crc32c: crc,
149 },
150 ))
151}
152
153pub async fn streaming_passthrough(
155 body: StreamingBlob,
156) -> Result<(Bytes, ChunkManifest), CodecError> {
157 let mut read = blob_to_async_read(body);
158 let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
159 let mut crc: u32 = 0;
160 let mut total: u64 = 0;
161 let mut chunk = vec![0u8; 64 * 1024];
162 loop {
163 let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
164 if n == 0 {
165 break;
166 }
167 crc = crc32c::crc32c_append(crc, &chunk[..n]);
168 total += n as u64;
169 buf.extend_from_slice(&chunk[..n]);
170 }
171 let len = buf.len() as u64;
172 Ok((
173 Bytes::from(buf),
174 ChunkManifest {
175 codec: CodecKind::Passthrough,
176 original_size: total,
177 compressed_size: len,
178 crc32c: crc,
179 },
180 ))
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use futures::stream;
    use futures::stream::StreamExt;

    // Drains a blob to completion; panics on any chunk error.
    async fn collect(blob: StreamingBlob) -> Bytes {
        let mut buf = BytesMut::new();
        let mut s = blob;
        while let Some(chunk) = s.next().await {
            buf.extend_from_slice(&chunk.unwrap());
        }
        buf.freeze()
    }

    // Builds a blob that yields `b` as a single chunk.
    fn make_blob(b: Bytes) -> StreamingBlob {
        let stream = stream::once(async move { Ok::<_, std::io::Error>(b) });
        StreamingBlob::wrap(stream)
    }

    // Decompressing a payload compressed with the `zstd` crate must recover
    // the original bytes exactly.
    #[tokio::test]
    async fn cpu_zstd_streaming_roundtrip_small() {
        let original = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let blob = make_blob(Bytes::from(compressed));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // The decoder must tolerate compressed input arriving in many small
    // chunks — zstd frame boundaries do not align with chunk boundaries.
    #[tokio::test]
    async fn cpu_zstd_streaming_handles_chunked_input() {
        let original = Bytes::from(vec![b'x'; 1_000_000]);
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let mut chunks = Vec::new();
        for chunk in compressed.chunks(1024) {
            chunks.push(Ok::<_, std::io::Error>(Bytes::copy_from_slice(chunk)));
        }
        let in_stream = stream::iter(chunks);
        let blob = StreamingBlob::wrap(in_stream);
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // A blob -> AsyncRead -> blob round trip must not alter the payload.
    #[tokio::test]
    async fn streaming_passes_through_for_passthrough() {
        let original = Bytes::from_static(b"hello");
        let blob = make_blob(original.clone());
        let out_blob = async_read_to_blob(blob_to_async_read(blob));
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // Compress via the streaming CPU path, check every manifest field, then
    // decompress through the streaming decoder and compare with the source.
    #[tokio::test]
    async fn streaming_compress_then_decompress_roundtrip() {
        let original = Bytes::from(vec![b'q'; 200_000]);
        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_cpu_zstd(blob, 3).await.unwrap();
        assert!(
            compressed.len() < original.len() / 100,
            "should be highly compressible"
        );
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    // Passthrough returns the input unchanged with equal original/compressed
    // sizes and a CRC over the payload.
    #[tokio::test]
    async fn streaming_passthrough_yields_input_unchanged() {
        let original = Bytes::from_static(b"hello world");
        let (out, manifest) = streaming_passthrough(make_blob(original.clone()))
            .await
            .unwrap();
        assert_eq!(out, original);
        assert_eq!(manifest.codec, CodecKind::Passthrough);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, original.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
    }
}