use bytes::{Buf, Bytes};
use http::HeaderMap;
use http_body::{Body, Frame, SizeHint};
use hyper::body::Incoming;
use serde::de::DeserializeOwned;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::ptr;
use std::rc::Rc;
use std::sync::atomic::{Ordering, compiler_fence};
use std::task::{Context, Poll, ready};
use crate::error::Error;
// Private module holding the supertrait of `Payload`. Because `sealed` is not
// public, code outside this file's module cannot name `Sealed`, so `Payload`
// can only be implemented for the types listed below.
mod sealed {
pub trait Sealed {}
impl Sealed for super::Aggregate {}
impl Sealed for bytes::Bytes {}
// WebSocket message types are only sealed when their feature is enabled.
#[cfg(feature = "tokio-tungstenite")]
impl Sealed for tungstenite::protocol::Message {}
#[cfg(feature = "tokio-websockets")]
impl Sealed for tokio_websockets::Message {}
}
pub trait Payload: sealed::Sealed + Sized {
fn coalesce(self) -> Vec<u8>;
fn z_coalesce(self) -> Result<Vec<u8>, Self>;
fn json<T>(self) -> Result<T, Error>
where
T: DeserializeOwned;
fn z_json<T>(self) -> Result<Result<T, Error>, Self>
where
T: DeserializeOwned,
{
self.z_coalesce()
.map(|data| deserialize_json(data.as_slice()))
}
fn bez_json<T>(self) -> Result<T, Error>
where
T: DeserializeOwned,
{
self.z_json().unwrap_or_else(Self::json)
}
fn utf8(self) -> Result<String, Error> {
deserialize_utf8(self.coalesce())
}
fn z_utf8(self) -> Result<Result<String, Error>, Self> {
self.z_coalesce().map(deserialize_utf8)
}
fn bez_utf8(self) -> Result<String, Error> {
self.z_utf8().unwrap_or_else(Self::utf8)
}
}
/// The fully buffered contents of a request body: its data frames plus any
/// trailers received after them.
///
/// The `PhantomData<Rc<()>>` marker makes this type neither `Send` nor `Sync`.
pub struct Aggregate {
payload: RequestPayload,
// `Rc<()>` is `!Send`/`!Sync`; this zero-sized marker opts `Aggregate` out of both.
_unsend: PhantomData<Rc<()>>,
}
/// A future that reads a `RequestBody` to completion and resolves to an
/// `Aggregate` holding every data frame and the merged trailers.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Coalesce {
body: RequestBody,
// `Some` while collecting; taken (left `None`) when the future resolves.
payload: Option<RequestPayload>,
}
/// Wraps a hyper `Incoming` body and tracks a byte budget for the data read from it.
#[derive(Debug)]
pub struct RequestBody {
// `None` until the first poll; then whether the inner body's exact size hint
// (if it reports one) fits within `remaining`.
has_capacity: Option<bool>,
// Data bytes still permitted to be read from `body`.
remaining: usize,
body: Incoming,
}
/// Data frames and trailers collected from a request body.
struct RequestPayload {
// Data frames in the order they were received.
frames: Vec<Bytes>,
// Trailers; later trailer frames are merged into the first via `extend`.
trailers: Option<HeaderMap>,
}
#[inline]
fn deserialize_json<T>(buf: &[u8]) -> Result<T, Error>
where
T: DeserializeOwned,
{
serde_json::from_slice(buf).map_err(Error::de_json)
}
#[inline]
fn deserialize_utf8(data: Vec<u8>) -> Result<String, Error> {
String::from_utf8(data).map_err(|_| Error::invalid_utf8_sequence("request body"))
}
/// Overwrites every remaining byte of `frame` with zero using volatile writes,
/// then advances `frame` by the same amount so it reports no remaining data.
///
/// "Unfenced" because this function issues no compiler fence itself; the
/// callers in this module follow their wipe(s) with `release_compiler_fence()`.
///
/// # Safety
///
/// The caller must guarantee that `frame` is the only handle to its backing
/// storage. Every call site in this module checks `Bytes::is_unique()` first.
/// Writing to storage shared with another `Bytes` would change data observed
/// through that other handle.
///
/// NOTE(review): the write pointer is derived from the shared `&[u8]` view of
/// `Bytes` (`as_ptr`), so writing through it may be undefined behavior under
/// Rust's aliasing rules even when the buffer is unique. Consider converting
/// the frame into a `BytesMut` (e.g. `Bytes::try_into_mut` in recent `bytes`
/// releases; verify the minimum version) and zeroing through `&mut [u8]`.
#[inline(never)]
unsafe fn unfenced_zeroize(frame: &mut Bytes) {
let len = frame.remaining();
let ptr = frame.as_ptr() as *mut u8;
for idx in 0..len {
// SAFETY: `idx < len` keeps `ptr.add(idx)` within the frame's visible bytes;
// the caller guarantees exclusive access to the backing storage.
unsafe {
ptr::write_volatile(ptr.add(idx), 0);
}
}
// Consume the wiped bytes so later reads of `frame` see no data.
frame.advance(len);
}
/// Emits a compiler-only `Release` fence after the volatile wipes in this module.
///
/// With `Release` ordering, the compiler may not move earlier memory accesses
/// past later writes. No machine fence instruction is emitted.
#[inline(always)]
fn release_compiler_fence() {
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::Release)
}
impl Aggregate {
pub fn trailers(&self) -> Option<&HeaderMap> {
self.payload.trailers.as_ref()
}
pub fn is_empty(&self) -> bool {
self.len().is_some_and(|len| len == 0)
}
#[inline]
pub fn len(&self) -> Option<usize> {
self.payload
.frames()
.iter()
.map(Buf::remaining)
.try_fold(0usize, |len, remaining| len.checked_add(remaining))
}
}
impl Aggregate {
fn new(payload: RequestPayload) -> Self {
Self {
payload,
_unsend: PhantomData,
}
}
}
// Implements `Payload` for a message type by converting it to `Bytes` and
// delegating to the `Bytes` implementation.
//
// Arms:
// - `($ty)`: converts with `Bytes::from(self)`; a rejected `z_*` call hands the
//   `Bytes` back through `From::from`.
// - `($ty, $from)`: converts with `Bytes::from($from(self))`; hands back via `From::from`.
// - `($ty, $from, $into)`: as above, with `$into` turning the returned `Bytes`
//   back into `$ty` when a `z_*` call returns `Err`.
//
// `utf8`, `bez_json`, and `bez_utf8` are not overridden and use the trait
// defaults, which call the methods generated here.
//
// NOTE(review): on a rejected `z_*` call the caller receives `$into(bytes)`,
// not the original value. Any distinction carried by `$ty` beyond its bytes
// (e.g. a message kind) depends on what `$into` produces; confirm this round
// trip is acceptable.
#[cfg(any(feature = "tokio-tungstenite", feature = "tokio-websockets"))]
macro_rules! impl_payload_for_bytes_like {
($ty:ty) => {
impl_payload_for_bytes_like!($ty, |this| this, From::from);
};
($ty:ty, $from:expr) => {
impl_payload_for_bytes_like!($ty, $from, From::from);
};
($ty:ty, $from:expr, $into:expr) => {
impl Payload for $ty {
fn coalesce(self) -> Vec<u8> {
Payload::coalesce(Bytes::from($from(self)))
}
fn z_coalesce(self) -> Result<Vec<u8>, Self> {
Payload::z_coalesce(Bytes::from($from(self))).map_err($into)
}
fn json<T>(self) -> Result<T, Error>
where
T: DeserializeOwned,
{
Payload::json(Bytes::from($from(self)))
}
fn z_json<T>(self) -> Result<Result<T, Error>, Self>
where
T: DeserializeOwned,
{
Payload::z_json(Bytes::from($from(self))).map_err($into)
}
fn z_utf8(self) -> Result<Result<String, Error>, Self> {
Payload::z_utf8(Bytes::from($from(self))).map_err($into)
}
}
};
}
impl Payload for Aggregate {
fn coalesce(mut self) -> Vec<u8> {
let mut dest = self.len().map(Vec::with_capacity).unwrap_or_default();
for frame in self.payload.frames_mut().iter_mut() {
dest.extend_from_slice(frame.as_ref());
frame.advance(frame.remaining());
}
dest
}
fn json<T>(mut self) -> Result<T, Error>
where
T: DeserializeOwned,
{
if let [frame] = self.payload.frames_mut() {
let result = deserialize_json(frame.as_ref());
frame.advance(frame.remaining());
return result;
}
deserialize_json(self.coalesce().as_slice())
}
fn z_json<T>(mut self) -> Result<Result<T, Error>, Self>
where
T: DeserializeOwned,
{
if let [frame] = self.payload.frames_mut() {
if !frame.is_unique() {
return Err(self);
}
let result = deserialize_json(frame.as_ref());
unsafe {
unfenced_zeroize(frame);
}
release_compiler_fence();
return Ok(result);
}
self.z_coalesce()
.map(|data| deserialize_json(data.as_slice()))
}
fn z_coalesce(mut self) -> Result<Vec<u8>, Self> {
let mut dest = self.len().map(Vec::with_capacity).unwrap_or_default();
let payload = &mut self.payload;
if !payload.frames().iter().all(Bytes::is_unique) {
return Err(self);
}
for frame in payload.frames_mut().iter_mut() {
dest.extend_from_slice(frame.as_ref());
unsafe {
unfenced_zeroize(frame);
}
}
release_compiler_fence();
Ok(dest)
}
}
// tungstenite messages convert to `Bytes` with `Bytes::from` and back with `From::from`.
#[cfg(feature = "tokio-tungstenite")]
impl_payload_for_bytes_like!(tungstenite::protocol::Message);
// tokio-websockets messages expose their data through `into_payload`; bytes
// returned from a rejected `z_*` call are rebuilt with `Message::binary`.
#[cfg(feature = "tokio-websockets")]
impl_payload_for_bytes_like!(
tokio_websockets::Message,
tokio_websockets::Message::into_payload,
tokio_websockets::Message::binary
);
impl Payload for Bytes {
    /// Copies the bytes into a new `Vec<u8>` and drains `self` without zeroing it.
    fn coalesce(mut self) -> Vec<u8> {
        let copy = self.to_vec();
        self.advance(self.remaining());
        copy
    }

    /// Copies the bytes into a new `Vec<u8>` and zeroes the source buffer.
    ///
    /// Returns `Err(self)` untouched when the buffer is shared.
    fn z_coalesce(mut self) -> Result<Vec<u8>, Self> {
        if self.is_unique() {
            let copy = self.to_vec();
            // SAFETY: uniqueness was checked just above.
            unsafe {
                unfenced_zeroize(&mut self);
            }
            release_compiler_fence();
            Ok(copy)
        } else {
            Err(self)
        }
    }

    /// Parses the bytes as JSON in place and drains `self` without zeroing it.
    fn json<T>(mut self) -> Result<T, Error>
    where
        T: DeserializeOwned,
    {
        let parsed = deserialize_json(&self);
        self.advance(self.remaining());
        parsed
    }

    /// Parses the bytes as JSON in place and then zeroes the source buffer.
    ///
    /// Returns `Err(self)` untouched when the buffer is shared.
    fn z_json<T>(mut self) -> Result<Result<T, Error>, Self>
    where
        T: DeserializeOwned,
    {
        if self.is_unique() {
            let parsed = deserialize_json(&self);
            // SAFETY: uniqueness was checked just above.
            unsafe {
                unfenced_zeroize(&mut self);
            }
            release_compiler_fence();
            Ok(parsed)
        } else {
            Err(self)
        }
    }
}
impl Coalesce {
pub(super) fn new(body: RequestBody) -> Self {
Self {
body,
payload: Some(RequestPayload::new()),
}
}
}
fn already_read() -> Error {
Error::new("a request body can only be read once.")
}
fn unknown_frame_type() -> Error {
Error::new("unknown frame type received while reading a request body.")
}
impl Future for Coalesce {
    type Output = Result<Aggregate, Error>;

    /// Drains the request body, buffering data frames and merging trailer
    /// frames, and resolves with an `Aggregate` once the body ends.
    ///
    /// Errors from the body are returned as-is. Polling after the aggregate has
    /// been taken resolves to the "already read" error.
    fn poll(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
        let this = &mut *self;

        loop {
            let next = ready!(Pin::new(&mut this.body).poll_frame(context)?);
            let Some(frame) = next else {
                break;
            };

            let payload = this.payload.as_mut().ok_or_else(already_read)?;

            match frame.into_data() {
                Ok(data) => payload.frames.push(data),
                Err(frame) => {
                    let trailers = frame.into_trailers().map_err(|_| unknown_frame_type())?;
                    match &mut payload.trailers {
                        Some(existing) => existing.extend(trailers),
                        None => payload.trailers = Some(trailers),
                    }
                }
            }
        }

        let collected = this.payload.take().map(Aggregate::new);
        Poll::Ready(collected.ok_or_else(already_read))
    }
}
impl RequestBody {
pub(crate) fn new(body: Incoming, remaining: usize) -> Self {
Self {
has_capacity: None,
remaining,
body,
}
}
}
impl Body for RequestBody {
    type Data = Bytes;
    type Error = Error;

    /// Polls the next frame from the inner body while enforcing the byte budget.
    ///
    /// A body whose data frames add up to exactly the budget is accepted,
    /// including trailers that follow it. Only a data frame that would push the
    /// total past the budget is rejected with `Error::payload_too_large`. Once
    /// rejected, every later poll returns the same error without reading from
    /// the inner body again.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        context: &mut Context,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = &mut *self;

        // On the first poll, reject up front a body whose exact length (when
        // the inner body reports one) already exceeds the budget.
        let has_capacity = match this.has_capacity {
            Some(has_capacity) => has_capacity,
            None => {
                let fits = this.body.size_hint().exact().is_none_or(|exact| {
                    u64::try_from(this.remaining).is_ok_and(|remaining| remaining >= exact)
                });
                this.has_capacity = Some(fits);
                fits
            }
        };

        if !has_capacity {
            return Poll::Ready(Some(Err(Error::payload_too_large())));
        }

        let Some(frame) = ready!(Pin::new(&mut this.body).poll_frame(context)?) else {
            return Poll::Ready(None);
        };

        if let Some(data) = frame.data_ref() {
            let Some(remaining) = this.remaining.checked_sub(data.remaining()) else {
                // Exhaust the budget and latch the failure so later polls fail
                // fast instead of reading more of the body.
                this.remaining = 0;
                this.has_capacity = Some(false);
                return Poll::Ready(Some(Err(Error::payload_too_large())));
            };
            this.remaining = remaining;
        }

        Poll::Ready(Some(Ok(frame)))
    }

    /// Reports end-of-stream once the budget check has failed, or when the
    /// inner body has ended.
    ///
    /// An exhausted but not exceeded budget does not by itself end the stream,
    /// because the inner body may still yield trailers.
    fn is_end_stream(&self) -> bool {
        self.has_capacity.is_some_and(|has_capacity| !has_capacity) || self.body.is_end_stream()
    }

    /// Returns the inner body's size hint, clamped so that the upper bound does
    /// not exceed the remaining budget.
    fn size_hint(&self) -> SizeHint {
        let Ok(remaining) = u64::try_from(self.remaining) else {
            // The budget cannot be expressed as a `u64`, so only the inner
            // body's lower bound is reported.
            let mut hint = SizeHint::new();
            hint.set_lower(self.body.size_hint().lower());
            if cfg!(debug_assertions) {
                use std::sync::Once;
                static ONCE: Once = Once::new();
                ONCE.call_once(|| {
                    print!("warn: a lossy size hint must be used for RequestBody. ");
                    println!("usize::MAX exceeds u64::MAX on this platform.");
                });
            }
            return hint;
        };
        let mut hint = self.body.size_hint();
        if remaining < hint.lower() {
            hint.set_exact(remaining);
        } else {
            // `remaining >= lower` here, so the clamped upper bound never drops
            // below the lower bound (which would make `set_upper` panic).
            let upper = hint
                .upper()
                .map(|upper| upper.min(remaining))
                .unwrap_or(remaining);
            hint.set_upper(upper);
        }
        hint
    }
}
impl RequestPayload {
fn new() -> Self {
Self {
frames: Vec::with_capacity(9),
trailers: None,
}
}
#[inline]
fn frames(&self) -> &[Bytes] {
&self.frames
}
#[inline]
fn frames_mut(&mut self) -> &mut [Bytes] {
&mut self.frames
}
}