1use std::io;
23use std::pin::Pin;
24use std::task::{Context, Poll};
25
26use async_compression::Level;
27use async_compression::tokio::bufread::ZstdDecoder;
28use async_compression::tokio::write::ZstdEncoder;
29use bytes::Bytes;
30use futures::{Stream, StreamExt};
31use s3s::StdError;
32use s3s::dto::StreamingBlob;
33use s3s::stream::{ByteStream, RemainingLength};
34use s4_codec::multipart::{FrameHeader, write_frame};
35use s4_codec::{ChunkManifest, CodecError, CodecKind, CodecRegistry};
36use std::sync::Arc;
37use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader};
38use tokio_util::io::{ReaderStream, StreamReader};
39
40pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync {
47 let mapped = blob.map(|chunk| chunk.map_err(|e| io::Error::other(e.to_string())));
48 StreamReader::new(mapped)
49}
50
51pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
53 reader: R,
54) -> StreamingBlob {
55 let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
56 StreamingBlob::new(StreamWrapper { inner: stream })
57}
58
59pin_project_lite::pin_project! {
60 struct StreamWrapper<S> { #[pin] inner: S }
63}
64
65impl<S> Stream for StreamWrapper<S>
66where
67 S: Stream<Item = Result<Bytes, StdError>>,
68{
69 type Item = Result<Bytes, StdError>;
70 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
71 self.project().inner.poll_next(cx)
72 }
73 fn size_hint(&self) -> (usize, Option<usize>) {
74 self.inner.size_hint()
75 }
76}
77
impl<S> ByteStream for StreamWrapper<S>
where
    S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
{
    /// The wrapped reader's total byte count is not known up front, so report
    /// an unknown remaining length to s3s.
    fn remaining_length(&self) -> RemainingLength {
        RemainingLength::unknown()
    }
}
87
88pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
98 let read = blob_to_async_read(body);
99 let mut decoder = ZstdDecoder::new(BufReader::new(read));
100 decoder.multiple_members(true);
101 async_read_to_blob(decoder)
102}
103
104pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
106 matches!(
110 codec,
111 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
112 )
113}
114
115pub fn supports_streaming_compress(codec: CodecKind) -> bool {
116 #[cfg(feature = "nvcomp-gpu")]
117 {
118 matches!(
119 codec,
120 CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
121 )
122 }
123 #[cfg(not(feature = "nvcomp-gpu"))]
124 {
125 matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
126 }
127}
128
129pub async fn streaming_compress_cpu_zstd(
139 body: StreamingBlob,
140 level: i32,
141) -> Result<(Bytes, ChunkManifest), CodecError> {
142 let mut read = blob_to_async_read(body);
143 let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
144 let mut crc: u32 = 0;
145 let mut total_in: u64 = 0;
146 let mut in_buf = vec![0u8; 64 * 1024];
147
148 {
149 let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
150 loop {
151 let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
152 if n == 0 {
153 break;
154 }
155 crc = crc32c::crc32c_append(crc, &in_buf[..n]);
156 total_in += n as u64;
157 encoder
158 .write_all(&in_buf[..n])
159 .await
160 .map_err(CodecError::Io)?;
161 }
162 encoder.shutdown().await.map_err(CodecError::Io)?;
163 }
164
165 let compressed_len = compressed_buf.len() as u64;
166 Ok((
167 Bytes::from(compressed_buf),
168 ChunkManifest {
169 codec: CodecKind::CpuZstd,
170 original_size: total_in,
171 compressed_size: compressed_len,
172 crc32c: crc,
173 },
174 ))
175}
176
/// Default plaintext chunk size (4 MiB) used when framing a streamed body
/// for chunked compression.
pub const DEFAULT_S4F2_CHUNK_SIZE: usize = 4 * 1024 * 1024;

/// Default number of chunk-compression tasks kept in flight concurrently by
/// `streaming_compress_to_frames`.
pub const DEFAULT_S4F2_INFLIGHT: usize = 3;
198
199pub async fn streaming_compress_to_frames(
220 body: StreamingBlob,
221 registry: Arc<CodecRegistry>,
222 codec_kind: CodecKind,
223 chunk_size: usize,
224) -> Result<(Bytes, ChunkManifest), CodecError> {
225 streaming_compress_to_frames_with(
226 body,
227 registry,
228 codec_kind,
229 chunk_size,
230 DEFAULT_S4F2_INFLIGHT,
231 )
232 .await
233}
234
235pub async fn streaming_compress_to_frames_with(
240 body: StreamingBlob,
241 registry: Arc<CodecRegistry>,
242 codec_kind: CodecKind,
243 chunk_size: usize,
244 inflight: usize,
245) -> Result<(Bytes, ChunkManifest), CodecError> {
246 use bytes::BytesMut;
247 use futures::StreamExt as _;
248 use futures::stream::FuturesOrdered;
249
250 let inflight = inflight.max(1);
251 let mut read = blob_to_async_read(body);
252 let mut framed = BytesMut::with_capacity(chunk_size);
253 let mut rolling_crc: u32 = 0;
254 let mut total_in: u64 = 0;
255 let mut chunk_buf = vec![0u8; chunk_size];
256
257 type InFlight = futures::future::BoxFuture<'static, Result<(FrameHeader, Bytes), CodecError>>;
261 let mut queue: FuturesOrdered<InFlight> = FuturesOrdered::new();
262 let mut eof = false;
263
264 loop {
265 while !eof && queue.len() < inflight {
267 let mut filled = 0;
268 while filled < chunk_size {
269 let n = read
270 .read(&mut chunk_buf[filled..])
271 .await
272 .map_err(CodecError::Io)?;
273 if n == 0 {
274 break;
275 }
276 filled += n;
277 }
278 if filled == 0 {
279 eof = true;
280 break;
281 }
282
283 let chunk_slice = &chunk_buf[..filled];
284 let chunk_crc = crc32c::crc32c(chunk_slice);
285 rolling_crc = crc32c::crc32c_append(rolling_crc, chunk_slice);
286 total_in += filled as u64;
287
288 let header = FrameHeader {
289 codec: codec_kind,
290 original_size: filled as u64,
291 compressed_size: 0, crc32c: chunk_crc,
293 };
294 let original_chunk = Bytes::copy_from_slice(chunk_slice);
295 let registry = Arc::clone(®istry);
296 queue.push_back(Box::pin(async move {
297 let (compressed_chunk, _per_chunk_manifest) =
298 registry.compress(original_chunk, codec_kind).await?;
299 let mut header = header;
300 header.compressed_size = compressed_chunk.len() as u64;
301 Ok::<_, CodecError>((header, compressed_chunk))
302 }));
303 }
304
305 match queue.next().await {
307 Some(Ok((header, compressed_chunk))) => {
308 write_frame(&mut framed, header, &compressed_chunk);
309 }
310 Some(Err(e)) => return Err(e),
311 None => break,
312 }
313 }
314
315 let total_framed = framed.len() as u64;
316 Ok((
317 framed.freeze(),
318 ChunkManifest {
319 codec: codec_kind,
320 original_size: total_in,
321 compressed_size: total_framed,
322 crc32c: rolling_crc,
323 },
324 ))
325}
326
327pub async fn streaming_passthrough(
329 body: StreamingBlob,
330) -> Result<(Bytes, ChunkManifest), CodecError> {
331 let mut read = blob_to_async_read(body);
332 let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
333 let mut crc: u32 = 0;
334 let mut total: u64 = 0;
335 let mut chunk = vec![0u8; 64 * 1024];
336 loop {
337 let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
338 if n == 0 {
339 break;
340 }
341 crc = crc32c::crc32c_append(crc, &chunk[..n]);
342 total += n as u64;
343 buf.extend_from_slice(&chunk[..n]);
344 }
345 let len = buf.len() as u64;
346 Ok((
347 Bytes::from(buf),
348 ChunkManifest {
349 codec: CodecKind::Passthrough,
350 original_size: total,
351 compressed_size: len,
352 crc32c: crc,
353 },
354 ))
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use futures::stream;
    use futures::stream::StreamExt;

    /// Drains a blob into one contiguous `Bytes`, panicking on stream error.
    async fn collect(blob: StreamingBlob) -> Bytes {
        let mut acc = BytesMut::new();
        let mut body = blob;
        while let Some(piece) = body.next().await {
            acc.extend_from_slice(&piece.unwrap());
        }
        acc.freeze()
    }

    /// Builds a one-shot blob that yields `b` and then ends.
    fn make_blob(b: Bytes) -> StreamingBlob {
        StreamingBlob::wrap(stream::once(async move { Ok::<_, std::io::Error>(b) }))
    }

    #[tokio::test]
    async fn cpu_zstd_streaming_roundtrip_small() {
        let plain = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
        let packed = zstd::stream::encode_all(plain.as_ref(), 3).unwrap();
        let restored =
            collect(cpu_zstd_decompress_stream(make_blob(Bytes::from(packed)))).await;
        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn cpu_zstd_streaming_handles_chunked_input() {
        let plain = Bytes::from(vec![b'x'; 1_000_000]);
        let packed = zstd::stream::encode_all(plain.as_ref(), 3).unwrap();
        // Feed the compressed bytes 1 KiB at a time to exercise partial reads.
        let pieces: Vec<_> = packed
            .chunks(1024)
            .map(|c| Ok::<_, std::io::Error>(Bytes::copy_from_slice(c)))
            .collect();
        let blob = StreamingBlob::wrap(stream::iter(pieces));
        let restored = collect(cpu_zstd_decompress_stream(blob)).await;
        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn streaming_passes_through_for_passthrough() {
        let plain = Bytes::from_static(b"hello");
        let blob = make_blob(plain.clone());
        let echoed = collect(async_read_to_blob(blob_to_async_read(blob))).await;
        assert_eq!(echoed, plain);
    }

    #[tokio::test]
    async fn streaming_compress_then_decompress_roundtrip() {
        let plain = Bytes::from(vec![b'q'; 200_000]);
        let (packed, manifest) = streaming_compress_cpu_zstd(make_blob(plain.clone()), 3)
            .await
            .unwrap();
        assert!(
            packed.len() < plain.len() / 100,
            "should be highly compressible"
        );
        assert_eq!(manifest.codec, CodecKind::CpuZstd);
        assert_eq!(manifest.original_size, plain.len() as u64);
        assert_eq!(manifest.compressed_size, packed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&plain));

        let restored = collect(cpu_zstd_decompress_stream(make_blob(packed))).await;
        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn concatenated_zstd_frames_are_a_single_valid_stream() {
        let parts = [
            Bytes::from(vec![b'a'; 50_000]),
            Bytes::from(vec![b'b'; 50_000]),
            Bytes::from(vec![b'c'; 50_000]),
        ];

        // Independently encode each part, then butt the frames together; the
        // decoder (multiple_members) must treat them as one stream.
        let mut joined: Vec<u8> = Vec::new();
        let mut expected: Vec<u8> = Vec::new();
        for part in &parts {
            joined.extend_from_slice(&zstd::stream::encode_all(part.as_ref(), 3).unwrap());
            expected.extend_from_slice(part);
        }

        let restored =
            collect(cpu_zstd_decompress_stream(make_blob(Bytes::from(joined)))).await;
        assert_eq!(restored, Bytes::from(expected));
    }

    #[tokio::test]
    async fn streaming_chunked_compress_pipeline_roundtrip() {
        // Test-local chunked compressor: zstd-encodes each `chunk_size`
        // window independently and concatenates the resulting frames.
        async fn streaming_compress_chunked_cpu_zstd(
            body: StreamingBlob,
            chunk_size: usize,
        ) -> Result<(Bytes, ChunkManifest), CodecError> {
            let mut reader = blob_to_async_read(body);
            let mut packed: Vec<u8> = Vec::with_capacity(chunk_size / 2);
            let mut checksum: u32 = 0;
            let mut seen: u64 = 0;
            let mut window = vec![0u8; chunk_size];
            loop {
                // Fill one window, tolerating short reads, until full or EOF.
                let mut filled = 0;
                while filled < chunk_size {
                    let n = reader
                        .read(&mut window[filled..])
                        .await
                        .map_err(CodecError::Io)?;
                    if n == 0 {
                        break;
                    }
                    filled += n;
                }
                if filled == 0 {
                    break;
                }
                let plain = &window[..filled];
                checksum = crc32c::crc32c_append(checksum, plain);
                seen += filled as u64;
                let frame = zstd::stream::encode_all(plain, 3).map_err(CodecError::Io)?;
                packed.extend_from_slice(&frame);
            }
            let packed_len = packed.len() as u64;
            Ok((
                Bytes::from(packed),
                ChunkManifest {
                    codec: CodecKind::CpuZstd,
                    original_size: seen,
                    compressed_size: packed_len,
                    crc32c: checksum,
                },
            ))
        }

        // 256 KiB of varied data: a little-endian u32 counter.
        let plain = Bytes::from(
            (0u32..65_536)
                .flat_map(|n| n.to_le_bytes())
                .collect::<Vec<u8>>(),
        );
        assert_eq!(plain.len(), 262_144);

        let (packed, manifest) =
            streaming_compress_chunked_cpu_zstd(make_blob(plain.clone()), 32 * 1024)
                .await
                .unwrap();

        assert_eq!(manifest.original_size, plain.len() as u64);
        assert_eq!(manifest.compressed_size, packed.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&plain));

        let restored = collect(cpu_zstd_decompress_stream(make_blob(packed))).await;
        assert_eq!(restored, plain);
    }

    #[tokio::test]
    async fn streaming_passthrough_yields_input_unchanged() {
        let plain = Bytes::from_static(b"hello world");
        let (echoed, manifest) = streaming_passthrough(make_blob(plain.clone()))
            .await
            .unwrap();
        assert_eq!(echoed, plain);
        assert_eq!(manifest.codec, CodecKind::Passthrough);
        assert_eq!(manifest.original_size, plain.len() as u64);
        assert_eq!(manifest.compressed_size, plain.len() as u64);
        assert_eq!(manifest.crc32c, crc32c::crc32c(&plain));
    }
}