use futures::TryStreamExt;
use http::header::CONTENT_LENGTH;
use ntex::{
http::error::PayloadError,
util::{Bytes, BytesMut},
web::{self, HttpRequest},
};
use strum::IntoStaticStr;
use crate::pipeline::request_extensions::write_request_body_size;
#[derive(Debug, thiserror::Error, IntoStaticStr)]
pub enum ReadBodyStreamError {
#[error("Failed to read request body: {0}")]
#[strum(serialize = "PAYLOAD_READ_ERROR")]
PayloadReadError(#[from] PayloadError),
#[error("Content-Length header has invalid value")]
#[strum(serialize = "INVALID_HEADER")]
InvalidContentLengthHeader,
#[error("Content-Length exceeds the maximum allowed size: {0}")]
#[strum(serialize = "PAYLOAD_TOO_LARGE_CONTENT_LENGTH")]
PayloadTooLargeContentLength(usize),
#[error("Request body exceeds the maximum allowed size while reading the stream")]
#[strum(serialize = "PAYLOAD_TOO_LARGE_BODY_STREAM")]
PayloadTooLargeBodyStream,
}
impl ReadBodyStreamError {
pub fn status_code(&self) -> http::StatusCode {
match self {
Self::PayloadReadError(_) => http::StatusCode::UNPROCESSABLE_ENTITY,
Self::InvalidContentLengthHeader => http::StatusCode::BAD_REQUEST,
Self::PayloadTooLargeContentLength(_) | Self::PayloadTooLargeBodyStream => {
http::StatusCode::PAYLOAD_TOO_LARGE
}
}
}
pub fn error_code(&self) -> &'static str {
self.into()
}
}
#[inline]
pub async fn read_body_stream(
req: &HttpRequest,
mut body_stream: web::types::Payload,
max_size: usize,
) -> Result<Bytes, ReadBodyStreamError> {
let content_length: Option<usize> = {
let content_length_header = req.headers().get(CONTENT_LENGTH);
if let Some(content_length_header) = content_length_header {
let content_length_str = content_length_header
.to_str()
.map_err(|_| ReadBodyStreamError::InvalidContentLengthHeader)?;
let content_length: usize = content_length_str
.parse()
.map_err(|_| ReadBodyStreamError::InvalidContentLengthHeader)?;
if content_length > max_size {
write_request_body_size(req, content_length as u64);
return Err(ReadBodyStreamError::PayloadTooLargeContentLength(max_size));
}
Some(content_length)
} else {
None
}
};
let mut body = if let Some(content_length) = content_length {
BytesMut::with_capacity(content_length)
} else {
BytesMut::new()
};
while let Some(chunk) = body_stream.try_next().await? {
if chunk.len() > max_size.saturating_sub(body.len()) {
write_request_body_size(req, (body.len() + chunk.len()) as u64);
return Err(ReadBodyStreamError::PayloadTooLargeBodyStream);
}
body.extend_from_slice(&chunk);
}
Ok(body.freeze())
}