use bytes::{Buf, Bytes};
use http::{HeaderMap, StatusCode};
use http_body::Body;
use http_body_util::{LengthLimitError, Limited};
use hyper::body::Incoming;
use serde::Deserialize;
use serde::de::DeserializeOwned;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::ptr;
use std::rc::Rc;
use std::sync::atomic::{Ordering, compiler_fence};
use std::task::{Context, Poll, ready};
use crate::error::{BoxError, Error};
use crate::raise;
// Private module holding the supertrait that seals `Payload`. Because `sealed`
// is not `pub`, `sealed::Sealed` cannot be named outside the parent module, so
// only the types listed below can ever implement `Payload`.
mod sealed {
// Marker supertrait of `Payload`; carries no methods.
pub trait Sealed {}
impl Sealed for super::Aggregate {}
impl Sealed for bytes::Bytes {}
// WebSocket message types are only sealed in when their feature is enabled.
#[cfg(feature = "tokio-tungstenite")]
impl Sealed for tungstenite::protocol::Message {}
#[cfg(feature = "tokio-websockets")]
impl Sealed for tokio_websockets::Message {}
}
/// A body that can be copied out, decoded as UTF-8, or deserialized from JSON.
///
/// Every operation comes in three flavours:
///
/// - plain (`coalesce`, `json`, `utf8`): read the data without zeroing the
///   memory it came from;
/// - `z_*` ("zeroizing"): read the data and then overwrite the source bytes
///   with zeros. When the source cannot be zeroed (the implementations in this
///   module decline when a buffer is shared), `self` is handed back in `Err`
///   untouched;
/// - `bez_*` ("best-effort zeroizing"): try the `z_*` method and fall back to
///   the plain one when it hands `self` back.
///
/// The trait is sealed: only the types registered in the private `sealed`
/// module can implement it.
pub trait Payload: sealed::Sealed + Sized {
    /// Copies the payload into one contiguous `Vec<u8>`, consuming `self`.
    fn coalesce(self) -> Vec<u8>;

    /// Zeroizing counterpart of [`Payload::coalesce`].
    ///
    /// Returns `Err(self)` when the source bytes could not be zeroed.
    fn z_coalesce(self) -> Result<Vec<u8>, Self>;

    /// Deserializes the payload as JSON into `T`.
    ///
    /// # Errors
    ///
    /// The implementations in this module expect a JSON object of the form
    /// `{"data": ...}` and report a 400 error when parsing fails.
    fn json<T>(self) -> Result<T, Error>
    where
        T: DeserializeOwned;

    /// Zeroizing counterpart of [`Payload::json`].
    ///
    /// The outer `Err(self)` means nothing was read or zeroed; the inner
    /// `Result` is the outcome of parsing. By default this zeroizing-coalesces
    /// the payload and parses the copy as a `{"data": ...}` envelope.
    fn z_json<T>(self) -> Result<Result<T, Error>, Self>
    where
        T: DeserializeOwned,
    {
        match self.z_coalesce() {
            Ok(bytes) => Ok(deserialize_json(&bytes)),
            Err(this) => Err(this),
        }
    }

    /// Best-effort zeroizing JSON: uses [`Payload::z_json`] when possible and
    /// falls back to [`Payload::json`] otherwise.
    fn bez_json<T>(self) -> Result<T, Error>
    where
        T: DeserializeOwned,
    {
        match self.z_json() {
            Ok(parsed) => parsed,
            Err(this) => this.json(),
        }
    }

    /// Decodes the payload as UTF-8 text.
    ///
    /// # Errors
    ///
    /// Returns a 400 error when the bytes are not valid UTF-8.
    fn utf8(self) -> Result<String, Error> {
        let bytes = self.coalesce();
        deserialize_utf8(bytes)
    }

    /// Zeroizing counterpart of [`Payload::utf8`]; `Err(self)` when the
    /// source could not be zeroed.
    fn z_utf8(self) -> Result<Result<String, Error>, Self> {
        match self.z_coalesce() {
            Ok(bytes) => Ok(deserialize_utf8(bytes)),
            Err(this) => Err(this),
        }
    }

    /// Best-effort zeroizing UTF-8 decode: uses [`Payload::z_utf8`] when
    /// possible and falls back to [`Payload::utf8`] otherwise.
    fn bez_utf8(self) -> Result<String, Error> {
        match self.z_utf8() {
            Ok(decoded) => decoded,
            Err(this) => this.utf8(),
        }
    }
}
/// A fully buffered request body: every data frame plus any trailers,
/// as collected by [`Coalesce`].
///
/// `Aggregate` is neither `Send` nor `Sync` because of its `_unsend` marker.
pub struct Aggregate {
// The buffered data frames and trailers.
payload: RequestPayload,
// `Rc<()>` is neither `Send` nor `Sync`, so this zero-sized marker opts
// `Aggregate` out of both auto traits without storing anything.
_unsend: PhantomData<Rc<()>>,
}
/// Future that reads every frame of a size-limited request body and resolves
/// to an [`Aggregate`].
///
/// It resolves to a 413 error when the `Limited` wrapper reports a
/// `LengthLimitError`, and to a 400 error for any other body error.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Coalesce {
// The incoming body, capped by `Limited`.
body: Limited<Incoming>,
// `Some` while frames are being collected; taken when the body ends, so
// polling again after completion yields the "already been read" error.
payload: Option<RequestPayload>,
}
/// Frames and trailers accumulated from a request body.
struct RequestPayload {
// Data frames in the order they were received.
frames: Vec<Bytes>,
// Trailers, if any trailer frame arrived; later trailer frames are merged
// into the first with `HeaderMap::extend`.
trailers: Option<HeaderMap>,
}
/// Error for a body that has already been consumed (HTTP 500).
///
/// Generic over `T` so it can stand in for any `Result<T, Error>`, e.g. as the
/// default branch of `Option::map_or_else`.
fn already_read<T>() -> Result<T, Error> {
raise!(500, message = "The request body has already been read.")
}
/// Parses `buf` as a JSON object and returns the value of its `data` field.
///
/// The envelope has no `deny_unknown_fields`, so any other top-level fields
/// are ignored.
///
/// # Errors
///
/// Returns a 400 error when `buf` is not valid JSON or does not match the
/// `{"data": ...}` shape expected for `T`.
fn deserialize_json<T>(buf: &[u8]) -> Result<T, Error>
where
    T: DeserializeOwned,
{
    #[derive(Deserialize)]
    struct Envelope<D> {
        data: D,
    }

    let parsed: Result<Envelope<T>, _> = serde_json::from_slice(buf);
    match parsed {
        Ok(envelope) => Ok(envelope.data),
        Err(error) => raise!(400, error),
    }
}
#[inline]
fn deserialize_utf8(data: Vec<u8>) -> Result<String, Error> {
String::from_utf8(data).or_else(|error| raise!(400, error.utf8_error()))
}
fn into_future_error<T>(error: BoxError) -> Result<T, Error> {
if error.is::<LengthLimitError>() {
raise!(413, boxed = error);
}
raise!(400, boxed = error);
}
/// Overwrites the remaining bytes of `frame` with zeros using volatile writes,
/// then advances `frame` so it reports no bytes remaining.
///
/// "Unfenced": no fence is issued here. The call sites in this module follow
/// it with `release_compiler_fence()`.
///
/// # Safety
///
/// The caller must guarantee that `frame` is the only handle to its backing
/// buffer and that the buffer is writable memory it owns (not a `'static`
/// slice). Every call site in this module checks `Bytes::is_unique()` first.
///
/// NOTE(review): the write pointer comes from `<[u8]>::as_ptr()` via `Deref`,
/// i.e. it is derived from a shared reference. Miri's aliasing models may
/// reject writes through it even when the buffer is unique. Converting to
/// `BytesMut` first (e.g. `Bytes::try_into_mut`, if the `bytes` version in use
/// provides it) would avoid that. TODO confirm.
#[inline(never)]
unsafe fn unfenced_zeroize(frame: &mut Bytes) {
let len = frame.remaining();
let ptr = frame.as_ptr() as *mut u8;
for idx in 0..len {
// SAFETY: `idx < len == frame.remaining()`, so `ptr.add(idx)` stays within
// the frame's visible bytes; exclusive access is the caller's obligation.
// Volatile writes are used so the zeroing is not elided as a dead store.
unsafe {
ptr::write_volatile(ptr.add(idx), 0);
}
}
// Mark the now-zeroed bytes as consumed.
frame.advance(len);
}
/// Compiler-only `Release` fence issued after zeroizing frames.
///
/// `compiler_fence` emits no CPU instruction. With `Release` it only prevents
/// the compiler from moving earlier memory accesses (the volatile zeroing)
/// past later writes on this thread.
#[inline(always)]
fn release_compiler_fence() {
compiler_fence(Ordering::Release);
}
impl Aggregate {
pub fn trailers(&self) -> Option<&HeaderMap> {
self.payload.trailers.as_ref()
}
pub fn is_empty(&self) -> bool {
self.len().is_some_and(|len| len == 0)
}
#[inline]
pub fn len(&self) -> Option<usize> {
self.payload
.frames()
.iter()
.map(Buf::remaining)
.try_fold(0usize, |len, remaining| len.checked_add(remaining))
}
}
impl Aggregate {
/// Wraps a completed payload. Private: `Coalesce` constructs an `Aggregate`
/// once the body stream ends.
fn new(payload: RequestPayload) -> Self {
Self {
payload,
_unsend: PhantomData,
}
}
}
// Implements `Payload` for a WebSocket message type by converting it into
// `Bytes` and delegating each method to the `Bytes` implementation.
//
// Arms:
// - `($ty)`: `$ty` converts into `Bytes` with `Bytes::from`, and back with
//   `From::from`.
// - `($ty, $from)`: `$from` maps a `$ty` into something `Bytes::from` accepts;
//   conversion back still uses `From::from`.
// - `($ty, $from, $into)`: `$into` rebuilds a `$ty` from the `Bytes` handed
//   back when a `z_*` method declines (its `Err` path).
//
// `utf8`, `bez_json`, and `bez_utf8` are not generated; they come from the
// trait's default methods.
//
// NOTE(review): on the `Err` path the original message is rebuilt from raw
// bytes via `$into`. For a message type whose `From<Bytes>` (or `$into`)
// produces a binary message, a text message would come back as binary.
// Confirm this is acceptable.
#[cfg(any(feature = "tokio-tungstenite", feature = "tokio-websockets"))]
macro_rules! impl_payload_for_bytes_like {
($ty:ty) => {
impl_payload_for_bytes_like!($ty, |this| this, From::from);
};
($ty:ty, $from:expr) => {
impl_payload_for_bytes_like!($ty, $from, From::from);
};
($ty:ty, $from:expr, $into:expr) => {
impl Payload for $ty {
fn coalesce(self) -> Vec<u8> {
Payload::coalesce(Bytes::from($from(self)))
}
fn z_coalesce(self) -> Result<Vec<u8>, Self> {
Payload::z_coalesce(Bytes::from($from(self))).map_err($into)
}
fn json<T>(self) -> Result<T, Error>
where
T: DeserializeOwned,
{
Payload::json(Bytes::from($from(self)))
}
fn z_json<T>(self) -> Result<Result<T, Error>, Self>
where
T: DeserializeOwned,
{
Payload::z_json(Bytes::from($from(self))).map_err($into)
}
fn z_utf8(self) -> Result<Result<String, Error>, Self> {
Payload::z_utf8(Bytes::from($from(self))).map_err($into)
}
}
};
}
impl Payload for Aggregate {
    /// Copies every frame, in order, into one `Vec<u8>` and marks each frame
    /// as consumed. The frames' memory is not zeroed.
    fn coalesce(mut self) -> Vec<u8> {
        // `len()` is `None` only if the summed frame lengths overflow `usize`;
        // in that case start empty and let the vector grow.
        let mut dest = self.len().map(Vec::with_capacity).unwrap_or_default();
        for frame in self.payload.frames_mut() {
            dest.extend_from_slice(frame.as_ref());
            frame.advance(frame.remaining());
        }
        dest
    }

    /// Deserializes the body as a `{"data": ...}` JSON envelope.
    ///
    /// A single-frame body is parsed in place without copying; any other frame
    /// count is coalesced first. Frames are not zeroed.
    fn json<T>(mut self) -> Result<T, Error>
    where
        T: DeserializeOwned,
    {
        if let [frame] = self.payload.frames_mut() {
            let result = deserialize_json(frame.as_ref());
            frame.advance(frame.remaining());
            return result;
        }
        deserialize_json(self.coalesce().as_slice())
    }

    /// Zeroizing JSON deserialization.
    ///
    /// A single-frame body is parsed in place and then zeroed. Any other frame
    /// count goes through [`Payload::z_coalesce`]. Returns `Err(self)`
    /// untouched when a frame's buffer is shared.
    fn z_json<T>(mut self) -> Result<Result<T, Error>, Self>
    where
        T: DeserializeOwned,
    {
        if let [frame] = self.payload.frames_mut() {
            if !frame.is_unique() {
                return Err(self);
            }
            let result = deserialize_json(frame.as_ref());
            // SAFETY: `frame` was just verified to be the unique handle to its
            // buffer, which is the contract `unfenced_zeroize` requires.
            unsafe {
                unfenced_zeroize(frame);
            }
            release_compiler_fence();
            return Ok(result);
        }
        self.z_coalesce()
            .map(|data| deserialize_json(data.as_slice()))
    }

    /// Zeroizing coalesce: copies every frame into one `Vec<u8>`, then zeroes
    /// each frame's buffer.
    ///
    /// Returns `Err(self)` untouched if any frame's buffer is shared, because
    /// zeroing it would change bytes another handle can still read.
    fn z_coalesce(mut self) -> Result<Vec<u8>, Self> {
        // Check uniqueness *before* allocating so the rejection path (which
        // `bez_*` callers follow with a second, non-zeroizing coalesce) costs
        // no allocation. This mirrors `Bytes::z_coalesce`.
        if !self.payload.frames().iter().all(Bytes::is_unique) {
            return Err(self);
        }
        let mut dest = self.len().map(Vec::with_capacity).unwrap_or_default();
        for frame in self.payload.frames_mut() {
            dest.extend_from_slice(frame.as_ref());
            // SAFETY: every frame was verified unique above and none has been
            // cloned since, satisfying `unfenced_zeroize`'s contract.
            unsafe {
                unfenced_zeroize(frame);
            }
        }
        // A single fence after all frames have been zeroed.
        release_compiler_fence();
        Ok(dest)
    }
}
// `tungstenite::protocol::Message` uses the one-argument arm: it converts into
// `Bytes` with `Bytes::from` and back with `From::from`.
#[cfg(feature = "tokio-tungstenite")]
impl_payload_for_bytes_like!(tungstenite::protocol::Message);
// `tokio_websockets::Message` exposes its payload with `into_payload` and is
// rebuilt with `Message::binary` when a `z_*` method hands the bytes back.
#[cfg(feature = "tokio-websockets")]
impl_payload_for_bytes_like!(
tokio_websockets::Message,
tokio_websockets::Message::into_payload,
tokio_websockets::Message::binary
);
impl Payload for Bytes {
    /// Copies the unread bytes into a new `Vec<u8>` and marks them consumed.
    /// The source memory is not zeroed.
    fn coalesce(mut self) -> Vec<u8> {
        let copy = self.to_vec();
        let consumed = self.remaining();
        self.advance(consumed);
        copy
    }

    /// Copies the unread bytes out and then zeroes the source buffer.
    ///
    /// Returns `Err(self)` untouched when the buffer is shared with another
    /// handle.
    fn z_coalesce(mut self) -> Result<Vec<u8>, Self> {
        if self.is_unique() {
            let copy = self.to_vec();
            // SAFETY: `is_unique()` just confirmed this is the sole handle to
            // the buffer, as `unfenced_zeroize` requires.
            unsafe {
                unfenced_zeroize(&mut self);
            }
            release_compiler_fence();
            Ok(copy)
        } else {
            Err(self)
        }
    }

    /// Parses the bytes in place as a `{"data": ...}` JSON envelope, then
    /// marks them consumed. The source memory is not zeroed.
    fn json<T>(mut self) -> Result<T, Error>
    where
        T: DeserializeOwned,
    {
        let parsed = deserialize_json(&self);
        let consumed = self.remaining();
        self.advance(consumed);
        parsed
    }

    /// Parses the bytes in place and then zeroes the source buffer.
    ///
    /// Returns `Err(self)` untouched when the buffer is shared with another
    /// handle.
    fn z_json<T>(mut self) -> Result<Result<T, Error>, Self>
    where
        T: DeserializeOwned,
    {
        if self.is_unique() {
            let parsed = deserialize_json(&self);
            // SAFETY: `is_unique()` just confirmed this is the sole handle to
            // the buffer, as `unfenced_zeroize` requires.
            unsafe {
                unfenced_zeroize(&mut self);
            }
            release_compiler_fence();
            Ok(parsed)
        } else {
            Err(self)
        }
    }
}
impl Coalesce {
    /// Wraps a size-limited incoming body. The future starts with an empty
    /// payload that frames are appended to as they arrive.
    pub(super) fn new(body: Limited<Incoming>) -> Self {
        let payload = Some(RequestPayload::new());
        Self { body, payload }
    }
}
impl Future for Coalesce {
    type Output = Result<Aggregate, Error>;

    /// Drains frames from the body until it ends.
    ///
    /// Data frames are appended to the payload, and trailer frames are merged
    /// into the payload's trailers. On end of stream the payload is taken and
    /// wrapped in an [`Aggregate`].
    ///
    /// # Errors
    ///
    /// - 413/400 for body errors (see `into_future_error`);
    /// - 400 for a frame that is neither data nor trailers;
    /// - 500 ("already been read") if polled again after completion.
    fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
        let this = self.get_mut();
        // Borrow only the `body` field so `this.payload` stays accessible.
        let mut body = Pin::new(&mut this.body);

        while let Some(result) = ready!(body.as_mut().poll_frame(context)) {
            let frame = result.or_else(into_future_error)?;
            let pending = this.payload.as_mut().map_or_else(already_read, Ok)?;

            match frame.into_data() {
                Ok(data) => pending.frames.push(data),
                Err(frame) => {
                    let Ok(trailers) = frame.into_trailers() else {
                        return Poll::Ready(Err(Error::with_status(
                            StatusCode::BAD_REQUEST,
                            "unexpected frame type received while reading the request body",
                        )));
                    };
                    if let Some(existing) = pending.trailers.as_mut() {
                        existing.extend(trailers);
                    } else {
                        pending.trailers = Some(trailers);
                    }
                }
            }
        }

        // End of stream: hand the collected payload over exactly once.
        Poll::Ready(match this.payload.take() {
            Some(payload) => Ok(Aggregate::new(payload)),
            None => already_read(),
        })
    }
}
impl RequestPayload {
    /// Number of frame slots reserved up front. NOTE(review): the origin of
    /// this value is not visible here; presumably a tuning choice. Confirm.
    const INITIAL_FRAME_CAPACITY: usize = 9;

    /// Creates an empty payload with no trailers.
    fn new() -> Self {
        Self {
            frames: Vec::with_capacity(Self::INITIAL_FRAME_CAPACITY),
            trailers: None,
        }
    }

    /// Borrows the buffered data frames.
    #[inline]
    fn frames(&self) -> &[Bytes] {
        &self.frames
    }

    /// Mutably borrows the buffered data frames, without allowing the list
    /// itself to grow or shrink.
    #[inline]
    fn frames_mut(&mut self) -> &mut [Bytes] {
        &mut self.frames
    }
}