use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use async_compression::Level;
use async_compression::tokio::bufread::ZstdDecoder;
use async_compression::tokio::write::ZstdEncoder;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use s3s::StdError;
use s3s::dto::StreamingBlob;
use s3s::stream::{ByteStream, RemainingLength};
use s4_codec::multipart::{FrameHeader, write_frame};
use s4_codec::{ChunkManifest, CodecError, CodecKind, CodecRegistry};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf};
use tokio_util::io::{ReaderStream, StreamReader};
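
/// Converts a [`StreamingBlob`] into an [`AsyncRead`], mapping stream errors to
/// `io::Error` so the body can be consumed with tokio I/O utilities.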
pub fn blob_to_async_read(blob: StreamingBlob) -> impl AsyncRead + Unpin + Send + Sync + 'static {
let mapped = blob.map(|chunk| chunk.map_err(|e| io::Error::other(e.to_string())));
StreamReader::new(mapped)
}
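
/// Wraps an [`AsyncRead`] as a [`StreamingBlob`]; the remaining length is
/// reported as unknown.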
pub fn async_read_to_blob<R: AsyncRead + Unpin + Send + Sync + 'static>(
reader: R,
) -> StreamingBlob {
let stream = ReaderStream::new(reader).map(|res| res.map_err(|e| Box::new(e) as StdError));
StreamingBlob::new(StreamWrapper { inner: stream })
}
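
// Adapter stream that implements s3s's `ByteStream` so an arbitrary stream of
// `Result<Bytes, StdError>` can back a `StreamingBlob`.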
pin_project_lite::pin_project! {
struct StreamWrapper<S> { #[pin] inner: S }
}
impl<S> Stream for StreamWrapper<S>
where
S: Stream<Item = Result<Bytes, StdError>>,
{
type Item = Result<Bytes, StdError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl<S> ByteStream for StreamWrapper<S>
where
S: Stream<Item = Result<Bytes, StdError>> + Send + Sync,
{
fn remaining_length(&self) -> RemainingLength {
RemainingLength::unknown()
}
}
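
/// Streams a zstd-compressed body through an asynchronous decoder.
///
/// `multiple_members(true)` lets the decoder keep reading past frame
/// boundaries, so concatenated zstd frames decode as one contiguous stream.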
pub fn cpu_zstd_decompress_stream(body: StreamingBlob) -> StreamingBlob {
let read = blob_to_async_read(body);
let mut decoder = ZstdDecoder::new(BufReader::new(read));
decoder.multiple_members(true);
async_read_to_blob(decoder)
}
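
/// An [`AsyncRead`] wrapper that verifies the total size and CRC32C of the
/// bytes it yields. Verification happens at EOF; a mismatch surfaces as an
/// `io::Error` with [`io::ErrorKind::InvalidData`].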
pub struct Crc32cVerifyingReader<R> {
inner: R,
expected_crc: u32,
expected_size: u64,
rolling_crc: u32,
bytes_read: u64,
failed: bool,
}
impl<R> Crc32cVerifyingReader<R> {
pub fn new(inner: R, expected_crc: u32, expected_size: u64) -> Self {
Self {
inner,
expected_crc,
expected_size,
rolling_crc: 0,
bytes_read: 0,
failed: false,
}
}
#[cfg(test)]
pub fn rolling_crc(&self) -> u32 {
self.rolling_crc
}
#[cfg(test)]
pub fn bytes_read(&self) -> u64 {
self.bytes_read
}
}
impl<R> AsyncRead for Crc32cVerifyingReader<R>
where
R: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
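        // After a verification failure, report EOF so callers stop polling.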
if self.failed {
return Poll::Ready(Ok(()));
}
let pre_filled = buf.filled().len();
match Pin::new(&mut self.inner).poll_read(cx, buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Ready(Ok(())) => {
let new_filled = buf.filled().len();
if new_filled > pre_filled {
let chunk = &buf.filled()[pre_filled..new_filled];
self.rolling_crc = crc32c::crc32c_append(self.rolling_crc, chunk);
self.bytes_read = self.bytes_read.saturating_add(chunk.len() as u64);
Poll::Ready(Ok(()))
} else {
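                    // A zero-byte read signals EOF: verify the total size and
                    // rolling CRC before handing the final Ok(()) to the caller.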
if self.bytes_read != self.expected_size {
self.failed = true;
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"S4 streaming GET size mismatch: \
expected {} bytes, got {}",
self.expected_size, self.bytes_read
),
)));
}
if self.rolling_crc != self.expected_crc {
self.failed = true;
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"S4 streaming GET crc32c mismatch: \
expected {:#010x}, got {:#010x}",
self.expected_crc, self.rolling_crc
),
)));
}
Poll::Ready(Ok(()))
}
}
}
}
}
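
/// Returns `true` if data stored with `codec` can be decompressed in streaming
/// fashion (passthrough, CPU zstd, and nvCOMP zstd).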
pub fn supports_streaming_decompress(codec: CodecKind) -> bool {
matches!(
codec,
CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
)
}
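
/// Returns `true` if `codec` supports streaming compression; `NvcompZstd` is
/// only supported when the `nvcomp-gpu` feature is enabled.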
pub fn supports_streaming_compress(codec: CodecKind) -> bool {
#[cfg(feature = "nvcomp-gpu")]
{
matches!(
codec,
CodecKind::Passthrough | CodecKind::CpuZstd | CodecKind::NvcompZstd
)
}
#[cfg(not(feature = "nvcomp-gpu"))]
{
matches!(codec, CodecKind::Passthrough | CodecKind::CpuZstd)
}
}
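
/// Compresses a streaming body with CPU zstd at the given `level`, buffering
/// the full compressed output in memory. Returns the compressed bytes together
/// with a [`ChunkManifest`] carrying the original size and CRC32C of the
/// uncompressed input.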
pub async fn streaming_compress_cpu_zstd(
body: StreamingBlob,
level: i32,
) -> Result<(Bytes, ChunkManifest), CodecError> {
let mut read = blob_to_async_read(body);
let mut compressed_buf: Vec<u8> = Vec::with_capacity(256 * 1024);
let mut crc: u32 = 0;
let mut total_in: u64 = 0;
let mut in_buf = vec![0u8; 64 * 1024];
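    // Scope the encoder so its &mut borrow of `compressed_buf` ends before the
    // buffer is measured and frozen into `Bytes`.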
{
let mut encoder = ZstdEncoder::with_quality(&mut compressed_buf, Level::Precise(level));
loop {
let n = read.read(&mut in_buf).await.map_err(CodecError::Io)?;
if n == 0 {
break;
}
crc = crc32c::crc32c_append(crc, &in_buf[..n]);
total_in += n as u64;
encoder
.write_all(&in_buf[..n])
.await
.map_err(CodecError::Io)?;
}
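        // Finalize the zstd frame; without shutdown() the output would be truncated.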
encoder.shutdown().await.map_err(CodecError::Io)?;
}
let compressed_len = compressed_buf.len() as u64;
Ok((
Bytes::from(compressed_buf),
ChunkManifest {
codec: CodecKind::CpuZstd,
original_size: total_in,
compressed_size: compressed_len,
crc32c: crc,
},
))
}
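
/// Default chunk size (4 MiB) for the framed S4F2 streaming format.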
pub const DEFAULT_S4F2_CHUNK_SIZE: usize = 4 * 1024 * 1024;
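
/// Picks an S4F2 chunk size from the advertised content length: 1 MiB for
/// bodies up to 1 MiB, the 4 MiB default up to 100 MiB (or when the length is
/// unknown), and 16 MiB for anything larger.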
pub fn pick_chunk_size(content_length: Option<u64>) -> usize {
match content_length {
None => DEFAULT_S4F2_CHUNK_SIZE,
Some(len) if len <= 1024 * 1024 => 1024 * 1024,
Some(len) if len <= 100 * 1024 * 1024 => DEFAULT_S4F2_CHUNK_SIZE,
Some(_) => 16 * 1024 * 1024,
}
}
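
/// Default number of chunk-compression futures kept in flight by the framed pipeline.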
pub const DEFAULT_S4F2_INFLIGHT: usize = 3;
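
/// Compresses a streaming body into S4F2 frames using the default in-flight
/// pipeline depth. See [`streaming_compress_to_frames_with`].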
pub async fn streaming_compress_to_frames(
body: StreamingBlob,
registry: Arc<CodecRegistry>,
codec_kind: CodecKind,
chunk_size: usize,
expected_size: Option<u64>,
) -> Result<(Bytes, ChunkManifest), CodecError> {
streaming_compress_to_frames_with(
body,
registry,
codec_kind,
chunk_size,
DEFAULT_S4F2_INFLIGHT,
expected_size,
)
.await
}
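
/// Compresses a streaming body into S4F2 frames, keeping up to `inflight`
/// chunk-compression futures running concurrently while frames are written out
/// in order. If `expected_size` is given and the body ends short of it, a
/// [`CodecError::TruncatedStream`] is returned. The manifest's CRC32C and
/// original size describe the uncompressed input; its compressed size is the
/// total framed length.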
pub async fn streaming_compress_to_frames_with(
body: StreamingBlob,
registry: Arc<CodecRegistry>,
codec_kind: CodecKind,
chunk_size: usize,
inflight: usize,
expected_size: Option<u64>,
) -> Result<(Bytes, ChunkManifest), CodecError> {
use bytes::BytesMut;
use futures::StreamExt as _;
use futures::stream::FuturesOrdered;
let inflight = inflight.max(1);
let mut read = blob_to_async_read(body);
let mut framed = BytesMut::with_capacity(chunk_size);
let mut rolling_crc: u32 = 0;
let mut total_in: u64 = 0;
let mut chunk_buf = vec![0u8; chunk_size];
type InFlight = futures::future::BoxFuture<'static, Result<(FrameHeader, Bytes), CodecError>>;
let mut queue: FuturesOrdered<InFlight> = FuturesOrdered::new();
let mut eof = false;
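    // Bounded pipeline: read and enqueue up to `inflight` chunk compressions,
    // then await the oldest one and append its frame, preserving chunk order.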
loop {
while !eof && queue.len() < inflight {
let mut filled = 0;
while filled < chunk_size {
let n = read
.read(&mut chunk_buf[filled..])
.await
.map_err(CodecError::Io)?;
if n == 0 {
break;
}
filled += n;
}
if filled == 0 {
eof = true;
break;
}
let chunk_slice = &chunk_buf[..filled];
let chunk_crc = crc32c::crc32c(chunk_slice);
rolling_crc = crc32c::crc32c_append(rolling_crc, chunk_slice);
total_in += filled as u64;
let header = FrameHeader {
codec: codec_kind,
original_size: filled as u64,
                compressed_size: 0,
                crc32c: chunk_crc,
};
let original_chunk = Bytes::copy_from_slice(chunk_slice);
            let registry = Arc::clone(&registry);
queue.push_back(Box::pin(async move {
let (compressed_chunk, _per_chunk_manifest) =
registry.compress(original_chunk, codec_kind).await?;
let mut header = header;
header.compressed_size = compressed_chunk.len() as u64;
Ok::<_, CodecError>((header, compressed_chunk))
}));
}
match queue.next().await {
Some(Ok((header, compressed_chunk))) => {
write_frame(&mut framed, header, &compressed_chunk);
}
Some(Err(e)) => return Err(e),
None => break,
}
}
if let Some(expected) = expected_size
&& total_in < expected
{
return Err(CodecError::TruncatedStream {
expected,
got: total_in,
});
}
let total_framed = framed.len() as u64;
Ok((
framed.freeze(),
ChunkManifest {
codec: codec_kind,
original_size: total_in,
compressed_size: total_framed,
crc32c: rolling_crc,
},
))
}
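
/// Buffers a streaming body unchanged, returning the bytes plus a passthrough
/// [`ChunkManifest`] recording the body's size and CRC32C.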
pub async fn streaming_passthrough(
body: StreamingBlob,
) -> Result<(Bytes, ChunkManifest), CodecError> {
let mut read = blob_to_async_read(body);
let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
let mut crc: u32 = 0;
let mut total: u64 = 0;
let mut chunk = vec![0u8; 64 * 1024];
loop {
let n = read.read(&mut chunk).await.map_err(CodecError::Io)?;
if n == 0 {
break;
}
crc = crc32c::crc32c_append(crc, &chunk[..n]);
total += n as u64;
buf.extend_from_slice(&chunk[..n]);
}
let len = buf.len() as u64;
Ok((
Bytes::from(buf),
ChunkManifest {
codec: CodecKind::Passthrough,
original_size: total,
compressed_size: len,
crc32c: crc,
},
))
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::BytesMut;
use futures::stream;
use futures::stream::StreamExt;
#[test]
fn pick_chunk_size_thresholds() {
assert_eq!(pick_chunk_size(None), DEFAULT_S4F2_CHUNK_SIZE);
assert_eq!(pick_chunk_size(Some(0)), 1024 * 1024);
assert_eq!(pick_chunk_size(Some(64 * 1024)), 1024 * 1024);
assert_eq!(pick_chunk_size(Some(1024 * 1024)), 1024 * 1024);
assert_eq!(
pick_chunk_size(Some(1024 * 1024 + 1)),
DEFAULT_S4F2_CHUNK_SIZE
);
assert_eq!(
pick_chunk_size(Some(50 * 1024 * 1024)),
DEFAULT_S4F2_CHUNK_SIZE
);
assert_eq!(
pick_chunk_size(Some(100 * 1024 * 1024)),
DEFAULT_S4F2_CHUNK_SIZE
);
assert_eq!(
pick_chunk_size(Some(100 * 1024 * 1024 + 1)),
16 * 1024 * 1024
);
assert_eq!(
pick_chunk_size(Some(10 * 1024 * 1024 * 1024)),
16 * 1024 * 1024
);
}
async fn collect(blob: StreamingBlob) -> Bytes {
let mut buf = BytesMut::new();
let mut s = blob;
while let Some(chunk) = s.next().await {
buf.extend_from_slice(&chunk.unwrap());
}
buf.freeze()
}
fn make_blob(b: Bytes) -> StreamingBlob {
let stream = stream::once(async move { Ok::<_, std::io::Error>(b) });
StreamingBlob::wrap(stream)
}
#[tokio::test]
async fn cpu_zstd_streaming_roundtrip_small() {
let original = Bytes::from("the quick brown fox jumps over the lazy dog. ".repeat(100));
let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
let blob = make_blob(Bytes::from(compressed));
let out_blob = cpu_zstd_decompress_stream(blob);
let out = collect(out_blob).await;
assert_eq!(out, original);
}
#[tokio::test]
async fn cpu_zstd_streaming_handles_chunked_input() {
let original = Bytes::from(vec![b'x'; 1_000_000]);
let compressed = zstd::stream::encode_all(original.as_ref(), 3).unwrap();
let mut chunks = Vec::new();
for chunk in compressed.chunks(1024) {
chunks.push(Ok::<_, std::io::Error>(Bytes::copy_from_slice(chunk)));
}
let in_stream = stream::iter(chunks);
let blob = StreamingBlob::wrap(in_stream);
let out_blob = cpu_zstd_decompress_stream(blob);
let out = collect(out_blob).await;
assert_eq!(out, original);
}
#[tokio::test]
async fn streaming_passes_through_for_passthrough() {
let original = Bytes::from_static(b"hello");
let blob = make_blob(original.clone());
let out_blob = async_read_to_blob(blob_to_async_read(blob));
let out = collect(out_blob).await;
assert_eq!(out, original);
}
#[tokio::test]
async fn streaming_compress_then_decompress_roundtrip() {
let original = Bytes::from(vec![b'q'; 200_000]);
let blob = make_blob(original.clone());
let (compressed, manifest) = streaming_compress_cpu_zstd(blob, 3).await.unwrap();
assert!(
compressed.len() < original.len() / 100,
"should be highly compressible"
);
assert_eq!(manifest.codec, CodecKind::CpuZstd);
assert_eq!(manifest.original_size, original.len() as u64);
assert_eq!(manifest.compressed_size, compressed.len() as u64);
assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
let out = collect(decompressed_blob).await;
assert_eq!(out, original);
}
#[tokio::test]
async fn concatenated_zstd_frames_are_a_single_valid_stream() {
let chunk_a = Bytes::from(vec![b'a'; 50_000]);
let chunk_b = Bytes::from(vec![b'b'; 50_000]);
let chunk_c = Bytes::from(vec![b'c'; 50_000]);
let frame_a = zstd::stream::encode_all(chunk_a.as_ref(), 3).unwrap();
let frame_b = zstd::stream::encode_all(chunk_b.as_ref(), 3).unwrap();
let frame_c = zstd::stream::encode_all(chunk_c.as_ref(), 3).unwrap();
let mut concatenated: Vec<u8> = Vec::new();
concatenated.extend_from_slice(&frame_a);
concatenated.extend_from_slice(&frame_b);
concatenated.extend_from_slice(&frame_c);
let expected: Vec<u8> = chunk_a
.iter()
.chain(chunk_b.iter())
.chain(chunk_c.iter())
.copied()
.collect();
let blob = make_blob(Bytes::from(concatenated));
let out_blob = cpu_zstd_decompress_stream(blob);
let out = collect(out_blob).await;
assert_eq!(out, Bytes::from(expected));
}
#[tokio::test]
async fn streaming_chunked_compress_pipeline_roundtrip() {
async fn streaming_compress_chunked_cpu_zstd(
body: StreamingBlob,
chunk_size: usize,
) -> Result<(Bytes, ChunkManifest), CodecError> {
let mut read = blob_to_async_read(body);
let mut compressed_buf: Vec<u8> = Vec::with_capacity(chunk_size / 2);
let mut crc: u32 = 0;
let mut total_in: u64 = 0;
let mut chunk_buf = vec![0u8; chunk_size];
loop {
let mut filled = 0;
while filled < chunk_size {
let n = read
.read(&mut chunk_buf[filled..])
.await
.map_err(CodecError::Io)?;
if n == 0 {
break;
}
filled += n;
}
if filled == 0 {
break;
}
crc = crc32c::crc32c_append(crc, &chunk_buf[..filled]);
total_in += filled as u64;
let compressed_chunk =
zstd::stream::encode_all(&chunk_buf[..filled], 3).map_err(CodecError::Io)?;
compressed_buf.extend_from_slice(&compressed_chunk);
}
let compressed_len = compressed_buf.len() as u64;
Ok((
Bytes::from(compressed_buf),
ChunkManifest {
codec: CodecKind::CpuZstd,
original_size: total_in,
compressed_size: compressed_len,
crc32c: crc,
},
))
}
let original = Bytes::from(
(0u32..65_536)
.flat_map(|n| n.to_le_bytes())
.collect::<Vec<u8>>(),
);
assert_eq!(original.len(), 262_144);
let blob = make_blob(original.clone());
let (compressed, manifest) = streaming_compress_chunked_cpu_zstd(blob, 32 * 1024)
.await
.unwrap();
assert_eq!(manifest.original_size, original.len() as u64);
assert_eq!(manifest.compressed_size, compressed.len() as u64);
assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
let decompressed_blob = cpu_zstd_decompress_stream(make_blob(compressed));
let out = collect(decompressed_blob).await;
assert_eq!(out, original);
}
#[tokio::test]
async fn streaming_passthrough_yields_input_unchanged() {
let original = Bytes::from_static(b"hello world");
let (out, manifest) = streaming_passthrough(make_blob(original.clone()))
.await
.unwrap();
assert_eq!(out, original);
assert_eq!(manifest.codec, CodecKind::Passthrough);
assert_eq!(manifest.original_size, original.len() as u64);
assert_eq!(manifest.compressed_size, original.len() as u64);
assert_eq!(manifest.crc32c, crc32c::crc32c(&original));
}
#[tokio::test]
async fn crc32c_verifying_reader_passes_correct_crc() {
use tokio::io::AsyncReadExt as _;
let original = Bytes::from(vec![0xa3u8; 17_000]);
let crc = crc32c::crc32c(&original);
let inner = blob_to_async_read(make_blob(original.clone()));
let mut verifier = Crc32cVerifyingReader::new(inner, crc, original.len() as u64);
let mut out = Vec::new();
verifier
.read_to_end(&mut out)
.await
.expect("clean stream must read cleanly");
assert_eq!(out, original.as_ref());
assert_eq!(verifier.rolling_crc(), crc);
assert_eq!(verifier.bytes_read(), original.len() as u64);
}
#[tokio::test]
async fn crc32c_verifying_reader_detects_corruption() {
use tokio::io::AsyncReadExt as _;
let original = Bytes::from_static(b"clean payload bytes");
let real_crc = crc32c::crc32c(&original);
let bogus_expected_crc = real_crc.wrapping_add(1);
let inner = blob_to_async_read(make_blob(original.clone()));
let mut verifier =
Crc32cVerifyingReader::new(inner, bogus_expected_crc, original.len() as u64);
let mut out = Vec::new();
let err = verifier
.read_to_end(&mut out)
.await
.expect_err("CRC mismatch must surface as io::Error");
assert_eq!(err.kind(), io::ErrorKind::InvalidData);
let msg = err.to_string();
assert!(
msg.contains("crc32c mismatch"),
"error must mention CRC mismatch, got `{msg}`"
);
assert_eq!(out, original.as_ref());
}
#[tokio::test]
async fn streaming_compress_truncated_input_returns_truncated_stream_error() {
use s4_codec::cpu_zstd::CpuZstd;
let registry =
Arc::new(CodecRegistry::new(CodecKind::CpuZstd).with(Arc::new(CpuZstd::default())));
let actual = Bytes::from(vec![b'z'; 4096]);
let advertised: u64 = 16 * 1024;
let blob = make_blob(actual.clone());
let err = streaming_compress_to_frames(
blob,
registry,
CodecKind::CpuZstd,
1024,
Some(advertised),
)
.await
.expect_err("truncated stream must error");
match err {
CodecError::TruncatedStream { expected, got } => {
assert_eq!(expected, advertised);
assert_eq!(got, actual.len() as u64);
}
other => panic!("expected TruncatedStream, got {other:?}"),
}
}
}