rama_http/layer/
map_request_body.rs

1//! Apply a transformation to the request body.
2//!
3//! # Example
4//!
5//! ```
6//! use rama_http::{Body, Request, Response};
7//! use rama_http::dep::http_body;
8//! use bytes::Bytes;
9//! use std::convert::Infallible;
10//! use std::{pin::Pin, task::{ready, Context, Poll}};
11//! use rama_core::{Layer, Service, context};
12//! use rama_core::service::service_fn;
13//! use rama_http::layer::map_request_body::MapRequestBodyLayer;
14//! use rama_core::error::BoxError;
15//!
//! // A wrapper around a request `Body`
17//! struct BodyWrapper {
18//!     inner: Body,
19//! }
20//!
21//! impl BodyWrapper {
22//!     fn new(inner: Body) -> Self {
23//!         Self { inner }
24//!     }
25//! }
26//!
27//! impl http_body::Body for BodyWrapper {
28//!     // ...
29//!     # type Data = Bytes;
30//!     # type Error = BoxError;
31//!     # fn poll_frame(
32//!     #     self: Pin<&mut Self>,
33//!     #     cx: &mut Context<'_>
34//!     # ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> { unimplemented!() }
35//!     # fn is_end_stream(&self) -> bool { unimplemented!() }
36//!     # fn size_hint(&self) -> http_body::SizeHint { unimplemented!() }
37//! }
38//!
39//! async fn handle<B>(_: Request<B>) -> Result<Response, Infallible> {
40//!     // ...
41//!     # Ok(Response::new(Body::default()))
42//! }
43//!
44//! # #[tokio::main]
45//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
46//! let mut svc = (
//!     // Wrap request bodies in `BodyWrapper`
48//!     MapRequestBodyLayer::new(BodyWrapper::new),
49//! ).into_layer(service_fn(handle));
50//!
51//! // Call the service
52//! let request = Request::new(Body::default());
53//!
54//! svc.serve(context::Context::default(), request).await?;
55//! # Ok(())
56//! # }
57//! ```
58
59use crate::{Request, Response};
60use rama_core::{Context, Layer, Service};
61use rama_utils::macros::define_inner_service_accessors;
62use std::fmt;
63
/// Apply a transformation to the request body.
///
/// Layer form of [`MapRequestBody`]: every service this layer wraps has its
/// request bodies passed through the stored function before being served.
///
/// See the [module docs](crate::layer::map_request_body) for an example.
#[derive(Clone)]
pub struct MapRequestBodyLayer<F> {
    /// Mapping function handed to each [`MapRequestBody`] this layer produces.
    f: F,
}

impl<F> MapRequestBodyLayer<F> {
    /// Create a new [`MapRequestBodyLayer`] from the given mapping function.
    ///
    /// `f` is expected to turn the incoming request body into the body type
    /// accepted by the wrapped service.
    pub const fn new(f: F) -> Self {
        MapRequestBodyLayer { f }
    }
}
80
81impl<S, F> Layer<S> for MapRequestBodyLayer<F>
82where
83    F: Clone,
84{
85    type Service = MapRequestBody<S, F>;
86
87    fn layer(&self, inner: S) -> Self::Service {
88        MapRequestBody::new(inner, self.f.clone())
89    }
90
91    fn into_layer(self, inner: S) -> Self::Service {
92        MapRequestBody::new(inner, self.f)
93    }
94}
95
96impl<F> fmt::Debug for MapRequestBodyLayer<F> {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        f.debug_struct("MapRequestBodyLayer")
99            .field("f", &std::any::type_name::<F>())
100            .finish()
101    }
102}
103
/// Apply a transformation to the request body.
///
/// Service middleware that runs every incoming request body through a mapping
/// function before handing the request to the wrapped service.
///
/// See the [module docs](crate::layer::map_request_body) for an example.
#[derive(Clone)]
pub struct MapRequestBody<S, F> {
    /// The wrapped service that receives the request with the mapped body.
    inner: S,
    /// Function applied to each request body.
    f: F,
}
112
113impl<S, F> MapRequestBody<S, F> {
114    /// Create a new [`MapRequestBody`].
115    ///
116    /// `F` is expected to be a function that takes a body and returns another body.
117    pub const fn new(service: S, f: F) -> Self {
118        Self { inner: service, f }
119    }
120
121    define_inner_service_accessors!();
122}
123
124impl<F, S, State, ReqBody, ResBody, NewReqBody> Service<State, Request<ReqBody>>
125    for MapRequestBody<S, F>
126where
127    S: Service<State, Request<NewReqBody>, Response = Response<ResBody>>,
128    State: Clone + Send + Sync + 'static,
129    ReqBody: Send + 'static,
130    NewReqBody: Send + 'static,
131    ResBody: Send + Sync + 'static,
132    F: Fn(ReqBody) -> NewReqBody + Send + Sync + 'static,
133{
134    type Response = S::Response;
135    type Error = S::Error;
136
137    async fn serve(
138        &self,
139        ctx: Context<State>,
140        req: Request<ReqBody>,
141    ) -> Result<Self::Response, Self::Error> {
142        let req = req.map(&self.f);
143        self.inner.serve(ctx, req).await
144    }
145}
146
147impl<S, F> fmt::Debug for MapRequestBody<S, F>
148where
149    S: fmt::Debug,
150{
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        f.debug_struct("MapRequestBody")
153            .field("inner", &self.inner)
154            .field("f", &std::any::type_name::<F>())
155            .finish()
156    }
157}