#![expect(
clippy::allow_attributes,
reason = "macro-generated `#[allow]` attributes whose underlying lints fire only for some expansions"
)]
#![expect(
clippy::unreachable,
reason = "SENTINEL_ERROR_CODE is only stored alongside an underlying body error, never observed at the unreachable branch"
)]
use crate::body::{Frame, SizeHint, StreamingBody};
use pin_project_lite::pin_project;
use rama_core::bytes::{Buf, Bytes, BytesMut};
use rama_core::error::BoxError;
use rama_core::futures::Stream;
use rama_core::futures::ready;
use rama_core::stream::io::StreamReader;
use std::io::ErrorKind;
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::AsyncRead;
/// Expands to the body of a `poll_frame` implementation for a wrapper whose
/// projected `inner` field is a `BodyInner`-style enum.
///
/// Codec variants forward straight to their inner body. The `Identity`
/// variant copies each data frame into `Bytes` and converts the error with
/// `Into`. The expansion site must have `BodyInnerProj`, `Poll`, `ready!` and
/// the `Buf` trait in scope.
macro_rules! compressed_body_poll_frame {
    ($self:expr, $cx:expr) => {
        match $self.project().inner.project() {
            BodyInnerProj::Gzip { inner } => inner.poll_frame($cx),
            BodyInnerProj::Deflate { inner } => inner.poll_frame($cx),
            BodyInnerProj::Brotli { inner } => inner.poll_frame($cx),
            BodyInnerProj::Zstd { inner } => inner.poll_frame($cx),
            BodyInnerProj::Identity { inner } => {
                let next = ready!(inner.poll_frame($cx));
                Poll::Ready(next.map(|result| {
                    result
                        .map(|frame| {
                            frame.map_data(|mut chunk| chunk.copy_to_bytes(chunk.remaining()))
                        })
                        .map_err(Into::into)
                }))
            }
        }
    };
}
pub(crate) use compressed_body_poll_frame;
/// Implements [`DecorateAsyncRead`] for a codec wrapper type `$codec<R>` that
/// is generic over the reader it decorates.
///
/// The closure-like syntax supplies the body of `apply`. `$input` binds the
/// `AsyncReadBody<B>` to decorate and `$quality` binds the requested
/// `CompressionLevel`. `$codec<R>` must expose
/// `get_pin_mut(Pin<&mut Self>) -> Pin<&mut R>`.
macro_rules! impl_decorate_async_read {
    ($codec:ident: |$input:pat_param, $quality:pat_param| $apply:block) => {
        impl<B: StreamingBody> DecorateAsyncRead for $codec<B> {
            type Input = AsyncReadBody<B>;
            type Output = $codec<AsyncReadBody<B>>;

            fn apply($input: Self::Input, $quality: CompressionLevel) -> Self::Output $apply

            fn get_pin_mut(output: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> {
                output.get_pin_mut()
            }
        }
    };
}
pub(crate) use impl_decorate_async_read;
/// `AsyncRead` view of a `StreamingBody`, used as the `Input` that codecs
/// decorate.
///
/// The layers work as follows:
/// - `BodyIntoStream` turns the body's data frames into a stream of chunks.
/// - `StreamErrorIntoIoError` stashes body errors and yields a sentinel
///   `io::Error` in their place.
/// - `StreamReader` reads the resulting byte stream.
pub(crate) type AsyncReadBody<B> = StreamReader<
    StreamErrorIntoIoError<BodyIntoStream<B>, <B as StreamingBody>::Error>,
    <B as StreamingBody>::Data,
>;
/// Wraps an `AsyncRead` input in another `AsyncRead` (such as a compression or
/// decompression codec) and gives pinned access back to the wrapped input.
pub(crate) trait DecorateAsyncRead {
    /// The reader being decorated.
    type Input: AsyncRead;
    /// The decorating reader that owns `Input`.
    type Output: AsyncRead;
    /// Builds the decorating reader around `input`, configured with `quality`.
    fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output;
    /// Projects a pinned decorating reader back to its pinned input.
    fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input>;
}
pin_project! {
    // Body that feeds a wrapped `StreamingBody` through the `AsyncRead`
    // decorator `M` and yields the decorator's output as `Bytes` data frames,
    // followed by any trailing frames of the wrapped body.
    pub(crate) struct WrapBody<M: DecorateAsyncRead> {
        // The decorating reader. It owns the original body through `AsyncReadBody`.
        #[pin]
        pub read: M::Output,
        // Read buffer filled by `poll_read_buf` and split off into data frames.
        buf: BytesMut,
        // Set once the decorating reader reports EOF (or a truncation exactly at
        // the end of the wrapped body). From then on, only the wrapped body's
        // remaining frames are polled.
        read_all_data: bool,
        // When true, a read error after some data was consumed ends the body
        // cleanly instead of surfacing the error.
        tolerate_decode_errors: bool,
    }
}
impl<M: DecorateAsyncRead> WrapBody<M> {
    /// Capacity allocated for `buf` at construction, and re-reserved before a
    /// read whenever the buffer's capacity has dropped to zero.
    // NOTE(review): 8096 is not a power of two; possibly meant to be 8192 — confirm.
    const INTERNAL_BUF_CAPACITY: usize = 8096;
}
impl<M: DecorateAsyncRead> WrapBody<M> {
#[allow(dead_code)]
pub(crate) fn new<B>(body: B, quality: CompressionLevel) -> Self
where
B: StreamingBody,
M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
{
let stream = BodyIntoStream::new(body);
let stream = StreamErrorIntoIoError::<_, B::Error>::new(stream);
let read = StreamReader::new(stream);
let read = M::apply(read, quality);
Self {
read,
buf: BytesMut::with_capacity(Self::INTERNAL_BUF_CAPACITY),
read_all_data: false,
tolerate_decode_errors: false,
}
}
pub(crate) fn with_tolerate_decode_errors(mut self, tolerate: bool) -> Self {
self.tolerate_decode_errors = tolerate;
self
}
}
impl<B, M> StreamingBody for WrapBody<M>
where
    B: StreamingBody<Error: Into<BoxError>>,
    M: DecorateAsyncRead<Input = AsyncReadBody<B>>,
{
    type Data = Bytes;
    type Error = BoxError;

    /// Polls the next frame of the decorated body.
    ///
    /// While the decorating reader still produces output, each successful read
    /// is returned as a `Bytes` data frame. Once it reports EOF, the remaining
    /// frames of the wrapped body are drained:
    /// - Trailers are forwarded.
    /// - Empty data frames are skipped.
    /// - A non-empty data frame is an error, because those bytes lie past the
    ///   end of the decoded stream.
    ///
    /// # Errors
    ///
    /// - The wrapped body's own error, recovered from `StreamErrorIntoIoError`.
    /// - A read (decode) error after some data was consumed. This is not an
    ///   error in two cases: the error is an `UnexpectedEof` at the true end of
    ///   the body, or `tolerate_decode_errors` is set. In both cases the body
    ///   ends cleanly instead.
    /// - Extra non-empty data after the decoded stream has ended.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let mut this = self.project();

        if !*this.read_all_data {
            if this.buf.capacity() == 0 {
                this.buf.reserve(Self::INTERNAL_BUF_CAPACITY);
            }

            let result =
                rama_core::stream::io::poll_read_buf(this.read.as_mut(), cx, &mut this.buf);

            match ready!(result) {
                Ok(0) => {
                    *this.read_all_data = true;
                }
                Ok(_) => {
                    let chunk = this.buf.split().freeze();
                    return Poll::Ready(Some(Ok(Frame::data(chunk))));
                }
                Err(err) => {
                    // Read the error state kept by `StreamErrorIntoIoError` in a
                    // single projection. Take the stashed body error (if any) and
                    // copy whether any data was read.
                    let (body_error, read_some_data): (Option<B::Error>, bool) = {
                        let state = M::get_pin_mut(this.read.as_mut()).get_pin_mut().project();
                        (state.error.take(), *state.read_some_data)
                    };

                    if let Some(body_error) = body_error {
                        return Poll::Ready(Some(Err(body_error.into())));
                    } else if err.raw_os_error() == Some(SENTINEL_ERROR_CODE) {
                        // The sentinel is only produced together with a stashed
                        // body error, which was handled above.
                        unreachable!()
                    } else if read_some_data {
                        let body_exhausted = M::get_pin_mut(this.read.as_mut())
                            .get_pin_mut()
                            .inner
                            .yielded_all_data;
                        if err.kind() == ErrorKind::UnexpectedEof && body_exhausted {
                            // The reader hit EOF exactly where the wrapped body ended.
                            // Go on to forward its trailing frames.
                            *this.read_all_data = true;
                        } else if *this.tolerate_decode_errors {
                            rama_core::telemetry::tracing::debug!(
                                "decompression: tolerating mid-stream decode error ({err}); ending body cleanly"
                            );
                            return Poll::Ready(None);
                        } else {
                            return Poll::Ready(Some(Err(err.into())));
                        }
                    }
                    // Otherwise nothing was read yet (e.g. an empty body):
                    // fall through and let the wrapped body decide how to end.
                }
            }
        }

        // Drain whatever the wrapped body still has after the decoded stream.
        let mut body = M::get_pin_mut(this.read).get_pin_mut().get_pin_mut();
        loop {
            match ready!(body.as_mut().poll_frame(cx)) {
                Some(Ok(frame)) if frame.is_trailers() => {
                    return Poll::Ready(Some(Ok(
                        frame.map_data(|mut data| data.copy_to_bytes(data.remaining()))
                    )));
                }
                Some(Ok(frame)) => match frame.into_data() {
                    Ok(bytes) if bytes.has_remaining() => {
                        return Poll::Ready(Some(Err(
                            "there are extra bytes after body has been decompressed".into(),
                        )));
                    }
                    // An empty data frame carries no payload. Skip it rather than
                    // ending the body, so trailers (or stray bytes) that follow it
                    // are still handled.
                    Ok(_) => {}
                    Err(_) => return Poll::Ready(None),
                },
                Some(Err(err)) => return Poll::Ready(Some(Err(err.into()))),
                None => return Poll::Ready(None),
            }
        }
    }
}
pin_project! {
    // Adapts a `StreamingBody` into a `Stream` of its data chunks.
    //
    // Data frames are yielded as stream items. The first non-data frame
    // (e.g. trailers) ends the stream and is kept in `non_data_frame`, so the
    // `StreamingBody` impl of this type can return it later.
    pub(crate) struct BodyIntoStream<B>
    where
        B: StreamingBody,
    {
        #[pin]
        body: B,
        // Set once the wrapped body returned `None` or a non-data frame. After
        // that, the `Stream` impl only yields `None`.
        yielded_all_data: bool,
        // The non-data frame that ended the data stream, if any, awaiting pickup.
        non_data_frame: Option<Frame<B::Data>>,
    }
}
#[allow(dead_code)]
impl<B> BodyIntoStream<B>
where
    B: StreamingBody,
{
    /// Wraps `body`. Nothing has been yielded yet.
    pub(crate) fn new(body: B) -> Self {
        Self {
            non_data_frame: None,
            yielded_all_data: false,
            body,
        }
    }

    /// Borrows the wrapped body.
    pub(crate) fn get_ref(&self) -> &B {
        let Self { body, .. } = self;
        body
    }

    /// Mutably borrows the wrapped body.
    pub(crate) fn get_mut(&mut self) -> &mut B {
        let Self { body, .. } = self;
        body
    }

    /// Projects a pinned adapter to its pinned wrapped body.
    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut B> {
        let projection = self.project();
        projection.body
    }

    /// Consumes the adapter and returns the wrapped body, discarding any
    /// stashed non-data frame.
    pub(crate) fn into_inner(self) -> B {
        let Self { body, .. } = self;
        body
    }
}
impl<B> Stream for BodyIntoStream<B>
where
    B: StreamingBody,
{
    type Item = Result<B::Data, B::Error>;

    /// Yields the body's data chunks and forwards body errors.
    ///
    /// The stream ends at the body's end or at its first non-data frame. That
    /// frame is stashed in `non_data_frame` instead of being dropped.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let proj = self.as_mut().project();
            if *proj.yielded_all_data {
                return Poll::Ready(None);
            }

            let Some(polled) = std::task::ready!(proj.body.poll_frame(cx)) else {
                // The body is finished; the next iteration reports end of stream.
                *proj.yielded_all_data = true;
                continue;
            };

            let frame = match polled {
                Ok(frame) => frame,
                Err(err) => return Poll::Ready(Some(Err(err))),
            };

            match frame.into_data() {
                Ok(data) => return Poll::Ready(Some(Ok(data))),
                Err(non_data) => {
                    *proj.yielded_all_data = true;
                    *proj.non_data_frame = Some(non_data);
                }
            }
        }
    }
}
impl<B> StreamingBody for BodyIntoStream<B>
where
    B: StreamingBody,
{
    type Data = B::Data;
    type Error = B::Error;

    /// Yields frames in this order:
    /// 1. Remaining data, as data frames (via the `Stream` impl).
    /// 2. The stashed non-data frame, if any.
    /// 3. Whatever the wrapped body produces next.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let next_data = std::task::ready!(self.as_mut().poll_next(cx));
        match next_data {
            Some(item) => Poll::Ready(Some(item.map(Frame::data))),
            None => {
                let this = self.project();
                match this.non_data_frame.take() {
                    Some(stashed) => Poll::Ready(Some(Ok(stashed))),
                    None => this.body.poll_frame(cx),
                }
            }
        }
    }

    #[inline]
    fn size_hint(&self) -> SizeHint {
        self.body.size_hint()
    }
}
pin_project! {
    // Stream adapter that turns `Result<T, E>` items into `Result<T, io::Error>`
    // so the stream can feed a `StreamReader`.
    //
    // An `Err(E)` from `inner` is stored in `error` and replaced with an
    // `io::Error` carrying `SENTINEL_ERROR_CODE`, so the reader's consumer can
    // recover the original error afterwards.
    pub(crate) struct StreamErrorIntoIoError<S, E> {
        #[pin]
        inner: S,
        // The most recent error yielded by `inner`, until someone `take()`s it.
        error: Option<E>,
        // Set once `inner` has yielded at least one `Ok` item.
        read_some_data: bool,
    }
}
impl<S, E> StreamErrorIntoIoError<S, E> {
    /// Wraps `inner`, starting with no stored error and no data seen.
    pub(crate) fn new(inner: S) -> Self {
        Self {
            error: None,
            read_some_data: false,
            inner,
        }
    }

    /// Projects a pinned adapter to its pinned inner stream.
    pub(crate) fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
        let projection = self.project();
        projection.inner
    }
}
impl<S, T, E> Stream for StreamErrorIntoIoError<S, E>
where
S: Stream<Item = Result<T, E>>,
{
type Item = Result<T, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
match ready!(this.inner.poll_next(cx)) {
None => Poll::Ready(None),
Some(Ok(value)) => {
*this.read_some_data = true;
Poll::Ready(Some(Ok(value)))
}
Some(Err(err)) => {
*this.error = Some(err);
Poll::Ready(Some(Err(io::Error::from_raw_os_error(SENTINEL_ERROR_CODE))))
}
}
}
}
/// Raw OS error code used as a marker by `StreamErrorIntoIoError`.
///
/// When the wrapped stream yields an error, that error is stored, and an
/// `io::Error` with this code is yielded in its place. `WrapBody` treats
/// seeing this code without a stored error as unreachable.
// NOTE(review): the value is presumably chosen so it cannot collide with real
// OS error codes — confirm.
pub(crate) const SENTINEL_ERROR_CODE: i32 = -837459418;
/// Compression level requested from a codec.
///
/// Converted to the codec crates' own level types by
/// `into_async_compression` and `into_compression_core`.
#[non_exhaustive]
#[derive(Default, Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum CompressionLevel {
    /// Maps to the codec's `Fastest` level.
    Fastest,
    /// Maps to the codec's `Best` level.
    Best,
    /// Maps to the codec's `Default` level. This is also the `Default` value of
    /// this enum.
    #[default]
    Default,
    /// An explicit, codec-specific level. Values above `i32::MAX` saturate to
    /// `i32::MAX` when converted.
    Precise(u32),
}
use async_compression::Level as AsyncCompressionLevel;
use compression_core::Level as CompressionCoreLevel;
impl CompressionLevel {
    /// Converts this level into the `async_compression` level type.
    ///
    /// A `Precise` value that does not fit in an `i32` saturates to `i32::MAX`.
    #[allow(dead_code)]
    pub(crate) fn into_async_compression(self) -> AsyncCompressionLevel {
        match self {
            Self::Precise(quality) => {
                let clamped = i32::try_from(quality).unwrap_or(i32::MAX);
                AsyncCompressionLevel::Precise(clamped)
            }
            Self::Default => AsyncCompressionLevel::Default,
            Self::Best => AsyncCompressionLevel::Best,
            Self::Fastest => AsyncCompressionLevel::Fastest,
        }
    }
}
impl CompressionLevel {
    /// Converts this level into the `compression_core` level type.
    ///
    /// A `Precise` value that does not fit in an `i32` saturates to `i32::MAX`.
    #[allow(dead_code)]
    pub(crate) fn into_compression_core(self) -> CompressionCoreLevel {
        match self {
            Self::Precise(quality) => {
                let clamped = i32::try_from(quality).unwrap_or(i32::MAX);
                CompressionCoreLevel::Precise(clamped)
            }
            Self::Default => CompressionCoreLevel::Default,
            Self::Best => CompressionCoreLevel::Best,
            Self::Fastest => CompressionCoreLevel::Fastest,
        }
    }
}