#![cfg_attr(docsrs, doc(cfg(feature = "protobuf")))]
use http::StatusCode;
use http_body_util::BodyExt;
use prost::Message;
use crate::extractors::FromRequest;
use crate::responder::Responder;
use crate::types::Request;
/// Wrapper around a value of type `T` that is transported as protobuf bytes.
///
/// In this file the wrapper is given two roles:
/// - as a `Responder`, the inner message is encoded and sent as the response
///   body with an `application/x-protobuf` content type;
/// - as a `FromRequest` extractor, the request body is decoded into `T`.
///
/// The inner value is public, so it can be reached with `.0` or by
/// destructuring (`let Protobuf(value) = ...`).
#[doc(alias = "protobuf")]
pub struct Protobuf<T>(pub T);
/// Failures that can occur while extracting a [`Protobuf`] value from a request.
///
/// Each variant is converted into an HTTP response by this type's `Responder`
/// implementation.
#[derive(Debug)]
pub enum ProtobufError {
/// The request's `Content-Type` was not accepted as protobuf; the expected
/// values are `application/x-protobuf` or `application/protobuf`.
InvalidContentType,
/// The request carries no `Content-Type` header.
MissingContentType,
/// Collecting the request body failed; holds the underlying error rendered
/// as a string.
BodyReadError(String),
/// The body bytes could not be decoded into the target message type; holds
/// the decoder's error rendered as a string.
ProtobufDecodeError(String),
}
impl Responder for ProtobufError {
    /// Converts the error into a response made of a `400 Bad Request` status
    /// paired with a message describing the failure.
    ///
    /// For the two string-carrying variants the stored detail is appended to
    /// the message.
    fn into_response(self) -> crate::types::Response {
        // Every variant shares the same status; only the message differs.
        let status = StatusCode::BAD_REQUEST;

        match self {
            Self::InvalidContentType => {
                let message =
                    "Invalid content type; expected application/x-protobuf or application/protobuf";
                (status, message).into_response()
            }
            Self::MissingContentType => {
                let message = "Missing content type header";
                (status, message).into_response()
            }
            Self::BodyReadError(reason) => {
                let message = format!("Failed to read request body: {reason}");
                (status, message).into_response()
            }
            Self::ProtobufDecodeError(reason) => {
                let message = format!("Failed to decode protobuf: {reason}");
                (status, message).into_response()
            }
        }
    }
}
impl<T> Responder for Protobuf<T>
where
T: Message,
{
fn into_response(self) -> crate::types::Response {
let buf = self.0.encode_to_vec();
let mut res = crate::types::Response::new(crate::body::TakoBody::from(buf));
res.headers_mut().insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/x-protobuf"),
);
res
}
}
/// Reports whether the `Content-Type` header names a protobuf media type.
///
/// Accepted media types are `application/x-protobuf` and
/// `application/protobuf`. The comparison ignores ASCII case and any
/// parameters after the first `;` (for example `; charset=utf-8`), as well as
/// whitespace surrounding the media type itself.
///
/// Returns `false` when the header is absent or its value cannot be read as a
/// string via `HeaderValue::to_str`.
fn is_protobuf_content_type(headers: &http::HeaderMap) -> bool {
    headers
        .get(http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .is_some_and(|content_type| {
            // Keep only the `type/subtype` part: parameters follow the first
            // `;`, and optional whitespace may precede it.
            let media_type = content_type
                .split_once(';')
                .map_or(content_type, |(media_type, _params)| media_type)
                .trim();

            // Media type and subtype tokens are case-insensitive.
            media_type.eq_ignore_ascii_case("application/x-protobuf")
                || media_type.eq_ignore_ascii_case("application/protobuf")
        })
}
impl<'a, T> FromRequest<'a> for Protobuf<T>
where
T: Message + Default + Send + 'static,
{
type Error = ProtobufError;
fn from_request(
req: &'a mut Request,
) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
async move {
if !is_protobuf_content_type(req.headers()) {
return Err(ProtobufError::InvalidContentType);
}
let body_bytes = req
.body_mut()
.collect()
.await
.map_err(|e| ProtobufError::BodyReadError(e.to_string()))?
.to_bytes();
let data = T::decode(&body_bytes[..])
.map_err(|e| ProtobufError::ProtobufDecodeError(e.to_string()))?;
Ok(Protobuf(data))
}
}
}