1use std::io;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25
26use async_compression::Level;
27use async_compression::tokio::bufread::ZstdDecoder;
28use async_compression::tokio::write::ZstdEncoder;
29use bytes::Bytes;
30use futures::{Stream, StreamExt};
31use s3s::StdError;
32use s3s::dto::StreamingBlob;
33use s3s::stream::{ByteStream, RemainingLength};
34use s4_codec::multipart::{FrameHeader, write_frame};
35use s4_codec::{ChunkManifest, CodecError, CodecKind, CodecRegistry};
36use std::sync::Arc;
37use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
38use tokio_util::io::{ReaderStream, StreamReader};
39
40pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync {
47 let mapped = blob.map(|chunk| chunk.map_err(|e| io::Error::other(e.to_string())));
48 StreamReader::new(mapped)
49}
50
51pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
53 reader: R,
54) -> StreamingBlob {
55 let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
56 StreamingBlob::new(StreamWrapper { inner: stream })
57}
58
pin_project_lite::pin_project! {
    // Thin newtype over an inner byte stream; exists so the `ByteStream`
    // impl below (reporting an unknown remaining length) can be attached
    // to the mapped `ReaderStream`.
    struct StreamWrapper<S> { #[pin] inner: S }
}
64
// `Stream` is implemented by pure delegation to the wrapped stream.
impl<S> Stream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, StdError>>,
{
    type Item = Result<Bytes, StdError>;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Pin-project to the inner stream and forward the poll verbatim.
        self.project().inner.poll_next(cx)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
77
// The wrapped reader is produced by streaming (de)compression, so the total
// byte count is not known up front; report it as unknown.
impl<S> ByteStream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
{
    fn remaining_length(&self) -> RemainingLength {
        RemainingLength::unknown()
    }
}
87
88pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
98 let read = blob_to_async_read(body);
99 let mut decoder = ZstdDecoder::new(BufReader::new(read));
100 decoder.multiple_members(true);
101 async_read_to_blob(decoder)
102}
103
/// Returns `true` when `codec` can be decompressed via the streaming path.
///
/// NOTE(review): unlike `supports_streaming_compress`, `NvcompZstd` is
/// accepted unconditionally here — presumably because its on-disk format is
/// standard zstd that the CPU decoder can read; confirm against the codec
/// registry.
pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
    matches!(
        codec,
        CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
    )
}
114
115pub fn supports_streaming_compress(codec: CodecKind) -> bool {
116 #[cfg(feature = "nvcomp-gpu")]
117 {
118 matches!(
119 codec,
120 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
121 )
122 }
123 #[cfg(not(feature = "nvcomp-gpu"))]
124 {
125 matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
126 }
127}
128
129pub async fn streaming_compress_cpu_zstd(
139 body: StreamingBlob,
140 level: i32,
141) -> Result<(Bytes, ChunkManifest), CodecError> {
142 let mut read = blob_to_async_read(body);
143 let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
144 let mut crc: u32 = 0;
145 let mut total_in: u64 = 0;
146 let mut in_buf = vec![0u8; 64 * 1024];
147
148 {
149 let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
150 loop {
151 let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
152 if n == 0 {
153 break;
154 }
155 crc = crc32c::crc32c_append(crc, &in_buf[..n]);
156 total_in += n as u64;
157 encoder
158 .write_all(&in_buf[..n])
159 .await
160 .map_err(CodecError::Io)?;
161 }
162 encoder.shutdown().await.map_err(CodecError::Io)?;
163 }
164
165 let compressed_len = compressed_buf.len() as u64;
166 Ok((
167 Bytes::from(compressed_buf),
168 ChunkManifest {
169 codec: CodecKind::CpuZstd,
170 original_size: total_in,
171 compressed_size: compressed_len,
172 crc32c: crc,
173 },
174 ))
175}
176
/// Default chunk size (4 MiB) for frame-based streaming compression
/// (intended as the `chunk_size` argument of `streaming_compress_to_frames`).
pub const DEFAULT_S4F2_CHUNK_SIZE: usize = 4 * 1024 * 1024;
184
185pub async fn streaming_compress_to_frames(
203 body: StreamingBlob,
204 registry: Arc<CodecRegistry>,
205 codec_kind: CodecKind,
206 chunk_size: usize,
207) -> Result<(Bytes, ChunkManifest), CodecError> {
208 use bytes::BytesMut;
209 let mut read = blob_to_async_read(body);
210 let mut framed = BytesMut::with_capacity(chunk_size);
211 let mut rolling_crc: u32 = 0;
212 let mut total_in: u64 = 0;
213 let mut chunk_buf = vec![0u8; chunk_size];
214
215 loop {
216 let mut filled = 0;
217 while filled < chunk_size {
218 let n = read
219 .read(&mut chunk_buf[filled..])
220 .await
221 .map_err(CodecError::Io)?;
222 if n == 0 {
223 break;
224 }
225 filled += n;
226 }
227 if filled == 0 {
228 break;
229 }
230
231 let chunk_slice = &chunk_buf[..filled];
232 let chunk_crc = crc32c::crc32c(chunk_slice);
233 rolling_crc = crc32c::crc32c_append(rolling_crc, chunk_slice);
234 total_in += filled as u64;
235
236 let original_chunk = Bytes::copy_from_slice(chunk_slice);
237 let (compressed_chunk, _per_chunk_manifest) =
238 registry.compress(original_chunk, codec_kind).await?;
239
240 let header = FrameHeader {
241 codec: codec_kind,
242 original_size: filled as u64,
243 compressed_size: compressed_chunk.len() as u64,
244 crc32c: chunk_crc,
245 };
246 write_frame(&mut framed, header, &compressed_chunk);
247 }
248
249 let total_framed = framed.len() as u64;
250 Ok((
251 framed.freeze(),
252 ChunkManifest {
253 codec: codec_kind,
254 original_size: total_in,
255 compressed_size: total_framed,
256 crc32c: rolling_crc,
257 },
258 ))
259}
260
261pub async fn streaming_passthrough(
263 body: StreamingBlob,
264) -> Result<(Bytes, ChunkManifest), CodecError> {
265 let mut read = blob_to_async_read(body);
266 let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
267 let mut crc: u32 = 0;
268 let mut total: u64 = 0;
269 let mut chunk = vec![0u8; 64 * 1024];
270 loop {
271 let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
272 if n == 0 {
273 break;
274 }
275 crc = crc32c::crc32c_append(crc, &chunk[..n]);
276 total += n as u64;
277 buf.extend_from_slice(&chunk[..n]);
278 }
279 let len = buf.len() as u64;
280 Ok((
281 Bytes::from(buf),
282 ChunkManifest {
283 codec: CodecKind::Passthrough,
284 original_size: total,
285 compressed_size: len,
286 crc32c: crc,
287 },
288 ))
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use futures::stream;
    use futures::stream::StreamExt;

    // Drains a `StreamingBlob` into a single contiguous buffer, panicking
    // on any stream error.
    async fn collect(blob: StreamingBlob) -> Bytes {
        let mut buf = BytesMut::new();
        let mut s = blob;
        while let Some(chunk) = s.next().await {
            buf.extend_from_slice(&chunk.unwrap());
        }
        buf.freeze()
    }

    // Builds a single-chunk blob from in-memory bytes.
    fn make_blob(b: Bytes) -> StreamingBlob {
        let stream = stream::once(async move { Ok::<_, std::io::Error>(b) });
        StreamingBlob::wrap(stream)
    }

    // A buffer compressed with the synchronous `zstd` crate must round-trip
    // through the async streaming decompressor.
    #[tokio::test]
    async fn cpu_zstd_streaming_roundtrip_small() {
        let original = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let blob = make_blob(Bytes::from(compressed));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // Decompression must tolerate the input arriving in small (1 KiB)
    // network-style chunks that split zstd blocks arbitrarily.
    #[tokio::test]
    async fn cpu_zstd_streaming_handles_chunked_input() {
        let original = Bytes::from(vec![b'x'; 1_000_000]);
        let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
        let mut chunks = Vec::new();
        for chunk in compressed.chunks(1024) {
            chunks.push(Ok::<_, std::io::Error>(Bytes::copy_from_slice(chunk)));
        }
        let in_stream = stream::iter(chunks);
        let blob = StreamingBlob::wrap(in_stream);
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // The blob -> AsyncRead -> blob adapter pair must be an identity.
    #[tokio::test]
    async fn streaming_passes_through_for_passthrough() {
        let original = Bytes::from_static(b"hello");
        let blob = make_blob(original.clone());
        let out_blob = async_read_to_blob(blob_to_async_read(blob));
        let out = collect(out_blob).await;
        assert_eq!(out, original);
    }

    // Full compress/decompress cycle: manifest fields must match the input
    // and the output must equal the original bytes.
    #[tokio::test]
    async fn streaming_compress_then_decompress_roundtrip() {
        let original = Bytes::from(vec![b'q'; 200_000]);
        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_cpu_zstd(blob, 3).await.unwrap();
        assert!(
            compressed.len() < original.len() / 100,
            "should be highly compressible"
        );
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    // Verifies the `multiple_members(true)` setting: back-to-back
    // independent zstd frames decode as one continuous stream.
    #[tokio::test]
    async fn concatenated_zstd_frames_are_a_single_valid_stream() {
        let chunk_a = Bytes::from(vec![b'a'; 50_000]);
        let chunk_b = Bytes::from(vec![b'b'; 50_000]);
        let chunk_c = Bytes::from(vec![b'c'; 50_000]);

        let frame_a = zstd::stream::encode_all(chunk_a.as_ref(), 3).unwrap();
        let frame_b = zstd::stream::encode_all(chunk_b.as_ref(), 3).unwrap();
        let frame_c = zstd::stream::encode_all(chunk_c.as_ref(), 3).unwrap();

        let mut concatenated: Vec<u8> = Vec::new();
        concatenated.extend_from_slice(&frame_a);
        concatenated.extend_from_slice(&frame_b);
        concatenated.extend_from_slice(&frame_c);

        let expected: Vec<u8> = chunk_a
            .iter()
            .chain(chunk_b.iter())
            .chain(chunk_c.iter())
            .copied()
            .collect();

        let blob = make_blob(Bytes::from(concatenated));
        let out_blob = cpu_zstd_decompress_stream(blob);
        let out = collect(out_blob).await;
        assert_eq!(out, Bytes::from(expected));
    }

    // Chunked compression (one zstd frame per fixed-size chunk) must still
    // yield a stream the frame-aware decompressor can read end to end.
    #[tokio::test]
    async fn streaming_chunked_compress_pipeline_roundtrip() {
        // Local model of a chunked compressor: reads `chunk_size` bytes at
        // a time and emits an independent zstd frame per chunk.
        async fn streaming_compress_chunked_cpu_zstd(
            body: StreamingBlob,
            chunk_size: usize,
        ) -> Result<(Bytes, ChunkManifest), CodecError> {
            let mut read = blob_to_async_read(body);
            let mut compressed_buf: Vec<u8> = Vec::with_capacity(chunk_size / 2);
            let mut crc: u32 = 0;
            let mut total_in: u64 = 0;
            let mut chunk_buf = vec![0u8; chunk_size];
            loop {
                let mut filled = 0;
                // Short reads are not EOF: fill the chunk before encoding.
                while filled < chunk_size {
                    let n = read
                        .read(&mut chunk_buf[filled..])
                        .await
                        .map_err(CodecError::Io)?;
                    if n == 0 {
                        break;
                    }
                    filled += n;
                }
                if filled == 0 {
                    break;
                }
                crc = crc32c::crc32c_append(crc, &chunk_buf[..filled]);
                total_in += filled as u64;
                let compressed_chunk =
                    zstd::stream::encode_all(&chunk_buf[..filled], 3).map_err(CodecError::Io)?;
                compressed_buf.extend_from_slice(&compressed_chunk);
            }
            let compressed_len = compressed_buf.len() as u64;
            Ok((
                Bytes::from(compressed_buf),
                ChunkManifest {
                    codec: CodecKind::CpuZstd,
                    original_size: total_in,
                    compressed_size: compressed_len,
                    crc32c: crc,
                },
            ))
        }

        // 256 KiB of varied (little-endian counter) data so chunks differ.
        let original = Bytes::from(
            (0u32..65_536)
                .flat_map(|n| n.to_le_bytes())
                .collect::<Vec<u8>>(),
        );
        assert_eq!(original.len(), 262_144);

        let blob = make_blob(original.clone());
        let (compressed, manifest) = streaming_compress_chunked_cpu_zstd(blob, 32 * 1024)
            .await
            .unwrap();

        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, compressed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));

        let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
        let out = collect(decompressed_blob).await;
        assert_eq!(out, original);
    }

    // Passthrough must return the bytes untouched with matching sizes and
    // a CRC over the raw input.
    #[tokio::test]
    async fn streaming_passthrough_yields_input_unchanged() {
        let original = Bytes::from_static(b"hello world");
        let (out, manifest) = streaming_passthrough(make_blob(original.clone()))
            .await
            .unwrap();
        assert_eq!(out, original);
        assert_eq!(manifest.codec, CodecKind::Passthrough);
        assert_eq!(manifest.original_size, original.len() as u64);
        assert_eq!(manifest.compressed_size, original.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
    }
}