axum_extra/
protobuf.rs

1//! Protocol Buffer extractor and response.
2
use axum_core::__composite_rejection as composite_rejection;
use axum_core::__define_rejection as define_rejection;
use axum_core::{
    extract::{rejection::BytesRejection, FromRequest, Request},
    response::{IntoResponse, Response},
    RequestExt,
};
use bytes::{Bytes, BytesMut};
use http::StatusCode;
use http_body_util::BodyExt;
use prost::Message;
14
/// A Protocol Buffer message extractor and response.
///
/// `Protobuf<T>` wraps a value whose type implements [`prost::Message`] and can be used both as
/// an extractor and as a response.
///
/// # As extractor
///
/// When extracting, the request body is decoded into `T`. Extraction fails with a
/// [`ProtobufRejection`] if:
///
/// - buffering the request body fails, or
/// - the body couldn't be decoded into the target Protocol Buffer message type.
///
/// See [`ProtobufRejection`] for more details.
///
/// No `Content-Type` header is required on the request; the extractor does not inspect it.
///
/// # Extractor example
///
/// ```rust,no_run
/// use axum::{routing::post, Router};
/// use axum_extra::protobuf::Protobuf;
///
/// #[derive(prost::Message)]
/// struct CreateUser {
///     #[prost(string, tag="1")]
///     email: String,
///     #[prost(string, tag="2")]
///     password: String,
/// }
///
/// async fn create_user(Protobuf(payload): Protobuf<CreateUser>) {
///     // payload is `CreateUser`
/// }
///
/// let app = Router::new().route("/users", post(create_user));
/// # let _: Router = app;
/// ```
///
/// # As response
///
/// When returned from a handler, the wrapped message is encoded into a freshly allocated
/// buffer, which becomes the response body.
///
/// If no `Content-Type` header is set, `Content-Type: application/octet-stream` is used
/// automatically.
///
/// # Response example
///
/// ```
/// use axum::{
///     extract::Path,
///     routing::get,
///     Router,
/// };
/// use axum_extra::protobuf::Protobuf;
///
/// #[derive(prost::Message)]
/// struct User {
///     #[prost(string, tag="1")]
///     username: String,
/// }
///
/// async fn get_user(Path(user_id): Path<String>) -> Protobuf<User> {
///     let user = find_user(user_id).await;
///     Protobuf(user)
/// }
///
/// async fn find_user(user_id: String) -> User {
///     // ...
///     # unimplemented!()
/// }
///
/// let app = Router::new().route("/users/{id}", get(get_user));
/// # let _: Router = app;
/// ```
#[derive(Clone, Copy, Debug, Default)]
#[must_use]
#[cfg_attr(docsrs, doc(cfg(feature = "protobuf")))]
pub struct Protobuf<T>(pub T);
95
96impl<T, S> FromRequest<S> for Protobuf<T>
97where
98    T: Message + Default,
99    S: Send + Sync,
100{
101    type Rejection = ProtobufRejection;
102
103    async fn from_request(req: Request, _: &S) -> Result<Self, Self::Rejection> {
104        let mut buf = req
105            .into_limited_body()
106            .collect()
107            .await
108            .map_err(ProtobufDecodeError)?
109            .aggregate();
110
111        match T::decode(&mut buf) {
112            Ok(value) => Ok(Protobuf(value)),
113            Err(err) => Err(ProtobufDecodeError::from_err(err).into()),
114        }
115    }
116}
117
// Lets `Protobuf<T>` be dereferenced to the wrapped `T`, so the inner message's fields can be
// accessed directly through the wrapper (the tests in this file rely on this, e.g. `input.foo`).
axum_core::__impl_deref!(Protobuf);
119
120impl<T> From<T> for Protobuf<T> {
121    fn from(inner: T) -> Self {
122        Self(inner)
123    }
124}
125
126impl<T> IntoResponse for Protobuf<T>
127where
128    T: Message + Default,
129{
130    fn into_response(self) -> Response {
131        let mut buf = BytesMut::with_capacity(self.0.encoded_len());
132        match &self.0.encode(&mut buf) {
133            Ok(()) => buf.into_response(),
134            Err(err) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response(),
135        }
136    }
137}
138
define_rejection! {
    #[status = UNPROCESSABLE_ENTITY]
    #[body = "Failed to decode the body"]
    /// Rejection type for [`Protobuf`].
    ///
    /// This rejection is used if the request body couldn't be decoded into the target type.
    ///
    /// It responds with `422 Unprocessable Entity` and the body `Failed to decode the body`.
    pub struct ProtobufDecodeError(Error);
}
147
composite_rejection! {
    /// Rejection used for [`Protobuf`].
    ///
    /// Contains one variant for each way the [`Protobuf`] extractor
    /// can fail:
    ///
    /// - [`ProtobufDecodeError`]: the body is not a valid encoding of the target message type.
    /// - [`BytesRejection`]: reading/buffering the request body failed.
    pub enum ProtobufRejection {
        ProtobufDecodeError,
        BytesRejection,
    }
}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::post, Router};

    // Request message sent by every test: a single string field with tag 1.
    #[derive(prost::Message)]
    struct Input {
        #[prost(string, tag = "1")]
        foo: String,
    }

    // The encoded request body `Input { foo: "bar" }` shared by all tests.
    fn encoded_bar() -> Vec<u8> {
        Input {
            foo: "bar".to_owned(),
        }
        .encode_to_vec()
    }

    #[tokio::test]
    async fn decode_body() {
        let echo = |input: Protobuf<Input>| async move { input.foo.to_owned() };
        let client = TestClient::new(Router::new().route("/", post(echo)));

        let res = client.post("/").body(encoded_bar()).await;

        assert_eq!(res.text().await, "bar");
    }

    #[tokio::test]
    async fn prost_decode_error() {
        // Declares tag 1 as `int32` while `Input` sends a string there, so decoding the
        // request body into this type is expected to fail.
        #[derive(prost::Message)]
        struct Expected {
            #[prost(int32, tag = "1")]
            test: i32,
        }

        let app = Router::new().route("/", post(|_: Protobuf<Expected>| async {}));
        let client = TestClient::new(app);

        let res = client.post("/").body(encoded_bar()).await;

        assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY);
    }

    #[tokio::test]
    async fn encode_body() {
        #[derive(prost::Message)]
        struct Output {
            #[prost(string, tag = "1")]
            result: String,
        }

        #[axum::debug_handler]
        async fn handler(input: Protobuf<Input>) -> Protobuf<Output> {
            Protobuf(Output {
                result: input.foo.to_owned(),
            })
        }

        let client = TestClient::new(Router::new().route("/", post(handler)));
        let res = client.post("/").body(encoded_bar()).await;

        assert_eq!(
            res.headers()["content-type"],
            mime::APPLICATION_OCTET_STREAM.as_ref()
        );

        let output = Output::decode(res.bytes().await).unwrap();
        assert_eq!(output.result, "bar");
    }
}