use axum_core::__composite_rejection as composite_rejection;
use axum_core::__define_rejection as define_rejection;
use axum_core::{
extract::{rejection::BytesRejection, FromRequest, Request},
response::{IntoResponse, Response},
RequestExt,
};
use bytes::BytesMut;
use http::StatusCode;
use http_body_util::BodyExt;
use prost::Message;
/// A protobuf-encoded body, usable both as an extractor and as a response.
///
/// As an extractor, the whole request body is buffered and decoded into `T`
/// with prost's `Message::decode`; the request is rejected if decoding fails.
/// The request's `Content-Type` header is not inspected.
///
/// As a response, `T` is encoded with prost's `Message::encode` and sent as
/// the response body with `Content-Type: application/octet-stream`.
#[derive(Debug, Default, Clone, Copy)]
#[cfg_attr(docsrs, doc(cfg(feature = "protobuf")))]
#[must_use]
pub struct Protobuf<T>(pub T);
impl<T, S> FromRequest<S> for Protobuf<T>
where
T: Message + Default,
S: Send + Sync,
{
type Rejection = ProtobufRejection;
async fn from_request(req: Request, _: &S) -> Result<Self, Self::Rejection> {
let mut buf = req
.into_limited_body()
.collect()
.await
.map_err(ProtobufDecodeError)?
.aggregate();
match T::decode(&mut buf) {
Ok(value) => Ok(Protobuf(value)),
Err(err) => Err(ProtobufDecodeError::from_err(err).into()),
}
}
}
// Gives `Protobuf<T>` `Deref` access to the wrapped `T`, so message fields can
// be read directly through the wrapper (e.g. `input.foo`).
// NOTE(review): presumably also provides `DerefMut`, per axum_core's
// `__impl_deref` macro — confirm against the macro definition.
axum_core::__impl_deref!(Protobuf);
impl<T> From<T> for Protobuf<T> {
fn from(inner: T) -> Self {
Self(inner)
}
}
impl<T> IntoResponse for Protobuf<T>
where
T: Message + Default,
{
fn into_response(self) -> Response {
let mut buf = BytesMut::with_capacity(self.0.encoded_len());
match &self.0.encode(&mut buf) {
Ok(()) => buf.into_response(),
Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
}
}
}
define_rejection! {
    #[status = UNPROCESSABLE_ENTITY]
    #[body = "Failed to decode the body"]
    /// Rejection type for [`Protobuf`].
    ///
    /// Used when the request body could not be decoded into the target
    /// message type. Responds with `422 Unprocessable Entity` and the body
    /// `Failed to decode the body`.
    pub struct ProtobufDecodeError(Error);
}
composite_rejection! {
    /// Rejection used for [`Protobuf`].
    ///
    /// Wraps either a [`ProtobufDecodeError`] or a [`BytesRejection`] (the
    /// rejection type of the `Bytes` body extractor).
    pub enum ProtobufRejection {
        ProtobufDecodeError,
        BytesRejection,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::post, Router};

    /// A well-formed protobuf body is decoded and handed to the handler.
    #[tokio::test]
    async fn decode_body() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        // Echo the decoded field back as a plain-text response.
        let client = TestClient::new(Router::new().route(
            "/",
            post(|Protobuf(input): Protobuf<Input>| async move { input.foo }),
        ));

        let payload = Input {
            foo: "bar".to_owned(),
        };
        let res = client.post("/").body(payload.encode_to_vec()).await;

        assert_eq!(res.text().await, "bar");
    }

    /// A body whose wire types do not match the target message is rejected
    /// with `422 Unprocessable Entity`.
    #[tokio::test]
    async fn prost_decode_error() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        // Same tag as `Input::foo`, but a varint field instead of a
        // length-delimited string, so `Input`'s bytes cannot decode as this.
        #[derive(prost::Message)]
        struct Expected {
            #[prost(int32, tag = "1")]
            test: i32,
        }

        let client = TestClient::new(
            Router::new().route("/", post(|_: Protobuf<Expected>| async {})),
        );

        let payload = Input {
            foo: "bar".to_owned(),
        };
        let res = client.post("/").body(payload.encode_to_vec()).await;

        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    /// Returning `Protobuf<T>` encodes the message as an octet-stream body.
    #[tokio::test]
    async fn encode_body() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        #[derive(prost::Message)]
        struct Output {
            #[prost(string, tag = "1")]
            result: String,
        }

        #[axum::debug_handler]
        async fn handler(Protobuf(input): Protobuf<Input>) -> Protobuf<Output> {
            Protobuf(Output { result: input.foo })
        }

        let client = TestClient::new(Router::new().route("/", post(handler)));

        let payload = Input {
            foo: "bar".to_owned(),
        };
        let res = client.post("/").body(payload.encode_to_vec()).await;

        assert_eq!(
            res.headers()["content-type"],
            mime::APPLICATION_OCTET_STREAM.as_ref()
        );

        let output = Output::decode(res.bytes().await).unwrap();
        assert_eq!(output.result, "bar");
    }
}