use axum_core::__composite_rejection as composite_rejection;
use axum_core::__define_rejection as define_rejection;
use axum_core::{
    extract::{rejection::BytesRejection, FromRequest, Request},
    response::{IntoResponse, Response},
    RequestExt,
};
use bytes::{Bytes, BytesMut};
use http::StatusCode;
use http_body_util::BodyExt;
use prost::Message;
14
/// A Protocol Buffers message extractor and response.
///
/// When used as an extractor, the whole request body is buffered and decoded
/// into `T` with [`prost::Message::decode`]. The request's `Content-Type`
/// header is not inspected. Failures are reported as [`ProtobufRejection`].
///
/// When returned from a handler, `T` is encoded and sent with
/// `content-type: application/octet-stream`.
///
/// The wrapped value is a public field. Handlers can therefore destructure it
/// directly, e.g. `Protobuf(msg): Protobuf<MyMessage>`.
#[derive(Default, Debug, Copy, Clone)]
#[cfg_attr(docsrs, doc(cfg(feature = "protobuf")))]
#[must_use]
pub struct Protobuf<T>(pub T);
95
96impl<T, S> FromRequest<S> for Protobuf<T>
97where
98 T: Message + Default,
99 S: Send + Sync,
100{
101 type Rejection = ProtobufRejection;
102
103 async fn from_request(req: Request, _: &S) -> Result<Self, Self::Rejection> {
104 let mut buf = req
105 .into_limited_body()
106 .collect()
107 .await
108 .map_err(ProtobufDecodeError)?
109 .aggregate();
110
111 match T::decode(&mut buf) {
112 Ok(value) => Ok(Protobuf(value)),
113 Err(err) => Err(ProtobufDecodeError::from_err(err).into()),
114 }
115 }
116}
117
// Implements `Deref` from `Protobuf<T>` to the wrapped `T`, which is what lets
// handlers read message fields through the wrapper (e.g. `input.foo`).
// NOTE(review): presumably `DerefMut` as well; confirm in axum-core's macro.
axum_core::__impl_deref!(Protobuf);
119
120impl<T> From<T> for Protobuf<T> {
121 fn from(inner: T) -> Self {
122 Self(inner)
123 }
124}
125
126impl<T> IntoResponse for Protobuf<T>
127where
128 T: Message + Default,
129{
130 fn into_response(self) -> Response {
131 let mut buf = BytesMut::with_capacity(self.0.encoded_len());
132 match &self.0.encode(&mut buf) {
133 Ok(()) => buf.into_response(),
134 Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
135 }
136 }
137}
138
define_rejection! {
    #[status = UNPROCESSABLE_ENTITY]
    #[body = "Failed to decode the body"]
    /// Rejection produced when the request body cannot be turned into the
    /// target [`Protobuf`] message.
    ///
    /// Uses status `422 Unprocessable Entity` with the message
    /// `Failed to decode the body` (per the attributes above). The underlying
    /// error is kept inside the rejection.
    pub struct ProtobufDecodeError(Error);
}
147
composite_rejection! {
    /// Rejection type for the [`Protobuf`] extractor.
    ///
    /// Has one variant for each rejection type listed below. Each variant
    /// wraps the rejection type of the same name.
    pub enum ProtobufRejection {
        // The body could not be decoded as the target message.
        ProtobufDecodeError,
        // axum-core's rejection for reading the request body as bytes.
        BytesRejection,
    }
}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::post, Router};

    /// A well-formed body is decoded and handed to the handler.
    #[tokio::test]
    async fn decode_body() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        let app = Router::new().route(
            "/",
            post(|Protobuf(msg): Protobuf<Input>| async move { msg.foo }),
        );

        let payload = Input {
            foo: String::from("bar"),
        }
        .encode_to_vec();

        let client = TestClient::new(app);
        let response = client.post("/").body(payload).await;

        assert_eq!(response.text().await, "bar");
    }

    /// Valid protobuf bytes that don't match the expected schema are rejected
    /// with `422 Unprocessable Entity`.
    #[tokio::test]
    async fn prost_decode_error() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        #[derive(prost::Message)]
        struct Expected {
            #[prost(int32, tag = "1")]
            test: i32,
        }

        let app = Router::new().route("/", post(|_: Protobuf<Expected>| async {}));

        // Field 1 is sent as a length-delimited string, but `Expected`
        // declares it as a varint, so decoding must fail.
        let payload = Input {
            foo: String::from("bar"),
        }
        .encode_to_vec();

        let client = TestClient::new(app);
        let response = client.post("/").body(payload).await;

        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    /// A `Protobuf` response is encoded and served as `application/octet-stream`.
    #[tokio::test]
    async fn encode_body() {
        #[derive(prost::Message)]
        struct Input {
            #[prost(string, tag = "1")]
            foo: String,
        }

        #[derive(prost::Message)]
        struct Output {
            #[prost(string, tag = "1")]
            result: String,
        }

        #[axum::debug_handler]
        async fn handler(Protobuf(input): Protobuf<Input>) -> Protobuf<Output> {
            Protobuf(Output { result: input.foo })
        }

        let app = Router::new().route("/", post(handler));

        let payload = Input {
            foo: String::from("bar"),
        }
        .encode_to_vec();

        let client = TestClient::new(app);
        let response = client.post("/").body(payload).await;

        assert_eq!(
            response.headers()["content-type"],
            mime::APPLICATION_OCTET_STREAM.as_ref()
        );

        let body = response.bytes().await;
        let decoded = Output::decode(body).expect("response body should be a valid `Output`");

        assert_eq!(decoded.result, "bar");
    }
}