use {crate::{content_types::APPLICATION_PROTOBUF,
error::JetError},
axum::{async_trait,
body::Bytes,
extract::{FromRequest,
Request},
http::header::CONTENT_TYPE},
prost::Message,
std::marker::PhantomData};
/// Largest request body, in bytes, that [`ProtobufRequest`] and
/// [`OptionalProtobufRequest`] will decode (10 MiB). Longer bodies are rejected
/// with `JetError::BodyTooLarge`.
const MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
/// Axum extractor that yields a protobuf message of type `T` decoded from the
/// request body.
///
/// The decoded message is exposed as the single public tuple field. Extraction
/// requires the request's `Content-Type` to indicate `APPLICATION_PROTOBUF` and
/// the body to be no longer than `MAX_BODY_SIZE` bytes.
pub struct ProtobufRequest<T: Message + Default>(pub T);
#[async_trait]
impl<S, T> FromRequest<S> for ProtobufRequest<T>
where
S: Send + Sync,
T: Message + Default,
{
type Rejection = JetError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let content_type = req
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !content_type.starts_with(APPLICATION_PROTOBUF) {
return Err(JetError::InvalidContentType {
expected: APPLICATION_PROTOBUF.to_string(),
actual: content_type.to_string(),
});
}
let bytes = Bytes::from_request(req, state)
.await
.map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
if bytes.len() > MAX_BODY_SIZE {
return Err(JetError::BodyTooLarge {
size: bytes.len(),
max: MAX_BODY_SIZE,
});
}
let message = T::decode(bytes)?;
Ok(ProtobufRequest(message))
}
}
/// Axum extractor for an optional protobuf message of type `T`.
///
/// The public tuple field is `None` when the request body is empty and
/// `Some(message)` when a non-empty body was decoded successfully.
pub struct OptionalProtobufRequest<T: Message + Default>(pub Option<T>);
#[async_trait]
impl<S, T> FromRequest<S> for OptionalProtobufRequest<T>
where
S: Send + Sync,
T: Message + Default,
{
type Rejection = JetError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let bytes = Bytes::from_request(req, state)
.await
.map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
if bytes.is_empty() {
return Ok(OptionalProtobufRequest(None));
}
if bytes.len() > MAX_BODY_SIZE {
return Err(JetError::BodyTooLarge {
size: bytes.len(),
max: MAX_BODY_SIZE,
});
}
let message = T::decode(bytes)?;
Ok(OptionalProtobufRequest(Some(message)))
}
}
/// Axum extractor that decodes the request body as protobuf message `T`,
/// rejecting bodies longer than `LIMIT` bytes.
///
/// The decoded message is the first (public) tuple field. The second field is
/// a non-public `PhantomData<T>` marker that carries no data.
pub struct ProtobufRequestWithLimit<T: Message + Default, const LIMIT: usize>(
    pub T,
    PhantomData<T>,
);
#[async_trait]
impl<S, T, const LIMIT: usize> FromRequest<S> for ProtobufRequestWithLimit<T, LIMIT>
where
S: Send + Sync,
T: Message + Default,
{
type Rejection = JetError;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let bytes = Bytes::from_request(req, state)
.await
.map_err(|e| JetError::BadRequest(format!("Failed to read body: {}", e)))?;
if bytes.len() > LIMIT {
return Err(JetError::BodyTooLarge {
size: bytes.len(),
max: LIMIT,
});
}
let message = T::decode(bytes)?;
Ok(ProtobufRequestWithLimit(message, PhantomData))
}
}