rama_http/layer/
header_config.rs

1//! Extract a header config from a request or response and insert it into the [`Extensions`] of its [`Context`].
2//!
3//! [`Extensions`]: rama_core::context::Extensions
4//! [`Context`]: rama_core::Context
5//!
6//! # Example
7//!
8//! ```rust
9//! use rama_http::layer::header_config::{HeaderConfigLayer, HeaderConfigService};
10//! use rama_http::service::web::{WebService};
11//! use rama_http::{Body, Request, StatusCode, HeaderName};
12//! use rama_core::{Context, Service, Layer};
13//! use serde::Deserialize;
14//!
15//! #[derive(Debug, Deserialize, Clone)]
16//! struct Config {
17//!     s: String,
18//!     n: i32,
19//!     m: Option<i32>,
20//!     b: bool,
21//! }
22//!
23//! #[tokio::main]
24//! async fn main() {
25//!     let service = HeaderConfigLayer::<Config>::required(HeaderName::from_static("x-proxy-config"))
26//!         .layer(WebService::default()
27//!             .get("/", |ctx: Context<()>| async move {
28//!                 let cfg = ctx.get::<Config>().unwrap();
29//!                 assert_eq!(cfg.s, "E&G");
30//!                 assert_eq!(cfg.n, 1);
31//!                 assert!(cfg.m.is_none());
32//!                 assert!(cfg.b);
33//!             }),
34//!         );
35//!
36//!     let request = Request::builder()
37//!         .header("x-proxy-config", "s=E%26G&n=1&b=true")
38//!         .body(Body::empty())
39//!         .unwrap();
40//!
41//!     let resp = service.serve(Context::default(), request).await.unwrap();
42//!     assert_eq!(resp.status(), StatusCode::OK);
43//! }
44//! ```
45
46use crate::{header::AsHeaderName, HeaderName};
47use crate::{
48    utils::{HeaderValueErr, HeaderValueGetter},
49    Request,
50};
51use rama_core::{error::BoxError, Context, Layer, Service};
52use rama_utils::macros::define_inner_service_accessors;
53use serde::de::DeserializeOwned;
54use std::{fmt, marker::PhantomData};
55
56/// Extract a header config from a request or response without consuming it.
57pub fn extract_header_config<H, T, G>(request: &G, header_name: H) -> Result<T, HeaderValueErr>
58where
59    H: AsHeaderName + Copy,
60    T: DeserializeOwned + Clone + Send + Sync + 'static,
61    G: HeaderValueGetter,
62{
63    let value = request.header_str(header_name)?;
64    let config = serde_html_form::from_str::<T>(value)
65        .map_err(|_| HeaderValueErr::HeaderInvalid(header_name.as_str().to_owned()))?;
66    Ok(config)
67}
68
69/// A [`Service`] which extracts a header config from a request or response
70/// and inserts it into the [`Extensions`] of that object.
71///
72/// [`Extensions`]: rama_core::context::Extensions
73pub struct HeaderConfigService<T, S> {
74    inner: S,
75    header_name: HeaderName,
76    optional: bool,
77    _marker: PhantomData<fn() -> T>,
78}
79
80impl<T, S> HeaderConfigService<T, S> {
81    /// Create a new [`HeaderConfigService`].
82    ///
83    /// Alias for [`HeaderConfigService::required`] if `!optional`
84    /// and [`HeaderConfigService::optional`] if `optional`.
85    pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
86        Self {
87            inner,
88            header_name,
89            optional,
90            _marker: PhantomData,
91        }
92    }
93
94    define_inner_service_accessors!();
95
96    /// Create a new [`HeaderConfigService`] with the given inner service
97    /// and header name, on which to extract the config,
98    /// and which will fail if the header is missing.
99    pub const fn required(inner: S, header_name: HeaderName) -> Self {
100        Self::new(inner, header_name, false)
101    }
102
103    /// Create a new [`HeaderConfigService`] with the given inner service
104    /// and header name, on which to extract the config,
105    /// and which will gracefully accept if the header is missing.
106    pub const fn optional(inner: S, header_name: HeaderName) -> Self {
107        Self::new(inner, header_name, true)
108    }
109}
110
111impl<T, S: fmt::Debug> fmt::Debug for HeaderConfigService<T, S> {
112    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113        f.debug_struct("HeaderConfigService")
114            .field("inner", &self.inner)
115            .field("header_name", &self.header_name)
116            .field("optional", &self.optional)
117            .field(
118                "_marker",
119                &format_args!("{}", std::any::type_name::<fn() -> T>()),
120            )
121            .finish()
122    }
123}
124
125impl<T, S> Clone for HeaderConfigService<T, S>
126where
127    S: Clone,
128{
129    fn clone(&self) -> Self {
130        Self {
131            inner: self.inner.clone(),
132            header_name: self.header_name.clone(),
133            optional: self.optional,
134            _marker: PhantomData,
135        }
136    }
137}
138
139impl<T, S, State, Body, E> Service<State, Request<Body>> for HeaderConfigService<T, S>
140where
141    S: Service<State, Request<Body>, Error = E>,
142    T: DeserializeOwned + Clone + Send + Sync + 'static,
143    State: Clone + Send + Sync + 'static,
144    Body: Send + Sync + 'static,
145    E: Into<BoxError> + Send + Sync + 'static,
146{
147    type Response = S::Response;
148    type Error = BoxError;
149
150    async fn serve(
151        &self,
152        mut ctx: Context<State>,
153        request: Request<Body>,
154    ) -> Result<Self::Response, Self::Error> {
155        let config = match extract_header_config::<_, T, _>(&request, &self.header_name) {
156            Ok(config) => config,
157            Err(err) => {
158                if self.optional && matches!(err, crate::utils::HeaderValueErr::HeaderMissing(_)) {
159                    tracing::debug!(error = %err, "failed to extract header config");
160                    return self.inner.serve(ctx, request).await.map_err(Into::into);
161                } else {
162                    return Err(err.into());
163                }
164            }
165        };
166        ctx.insert(config);
167        self.inner.serve(ctx, request).await.map_err(Into::into)
168    }
169}
170
171/// Layer which extracts a header config for the given HeaderName
172/// from a request or response and inserts it into the [`Extensions`] of that object.
173///
174/// [`Extensions`]: rama_core::context::Extensions
175pub struct HeaderConfigLayer<T> {
176    header_name: HeaderName,
177    optional: bool,
178    _marker: PhantomData<fn() -> T>,
179}
180
181impl<T> fmt::Debug for HeaderConfigLayer<T> {
182    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183        f.debug_struct("HeaderConfigLayer")
184            .field("header_name", &self.header_name)
185            .field("optional", &self.optional)
186            .field(
187                "_marker",
188                &format_args!("{}", std::any::type_name::<fn() -> T>()),
189            )
190            .finish()
191    }
192}
193
194impl<T> Clone for HeaderConfigLayer<T> {
195    fn clone(&self) -> Self {
196        Self {
197            header_name: self.header_name.clone(),
198            optional: self.optional,
199            _marker: PhantomData,
200        }
201    }
202}
203
204impl<T> HeaderConfigLayer<T> {
205    /// Create a new [`HeaderConfigLayer`] with the given header name,
206    /// on which to extract the config,
207    /// and which will fail if the header is missing.
208    pub fn required(header_name: HeaderName) -> Self {
209        Self {
210            header_name,
211            optional: false,
212            _marker: PhantomData,
213        }
214    }
215
216    /// Create a new [`HeaderConfigLayer`] with the given header name,
217    /// on which to extract the config,
218    /// and which will gracefully accept if the header is missing.
219    pub fn optional(header_name: HeaderName) -> Self {
220        Self {
221            header_name,
222            optional: true,
223            _marker: PhantomData,
224        }
225    }
226}
227
228impl<T, S> Layer<S> for HeaderConfigLayer<T> {
229    type Service = HeaderConfigService<T, S>;
230
231    fn layer(&self, inner: S) -> Self::Service {
232        HeaderConfigService::new(inner, self.header_name.clone(), self.optional)
233    }
234}
235
#[cfg(test)]
mod test {
    use serde::Deserialize;

    use crate::Method;

    use super::*;

    #[tokio::test]
    async fn test_header_config_required_happy_path() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=E%26G&n=1&b=true")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                let cfg: &Config = ctx.get().unwrap();
                assert_eq!(cfg.s, "E&G");
                assert_eq!(cfg.n, 1);
                assert!(cfg.m.is_none());
                assert!(cfg.b);

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_found() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=E%26G&n=1&b=true")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                let cfg: &Config = ctx.get().unwrap();
                assert_eq!(cfg.s, "E&G");
                assert_eq!(cfg.n, 1);
                assert!(cfg.m.is_none());
                assert!(cfg.b);

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_missing() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                assert!(ctx.get::<Config>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_missing_header() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_required_invalid_config() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=bar&n=1&b=invalid")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_optional_invalid_config() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=bar&n=1&b=invalid")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    // The layer must produce a service that behaves like the hand-built one.
    #[tokio::test]
    async fn test_header_config_layer_required_happy_path() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=E%26G&n=1&b=true")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                let cfg: &Config = ctx.get().unwrap();
                assert_eq!(cfg.s, "E&G");
                assert_eq!(cfg.n, 1);
                assert!(cfg.m.is_none());
                assert!(cfg.b);

                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderConfigLayer::<Config>::required(HeaderName::from_static("x-proxy-config"))
                .layer(inner_service);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_layer_optional_missing() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                assert!(ctx.get::<Config>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderConfigLayer::<Config>::optional(HeaderName::from_static("x-proxy-config"))
                .layer(inner_service);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_layer_required_missing_header() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service =
            HeaderConfigLayer::<Config>::required(HeaderName::from_static("x-proxy-config"))
                .layer(inner_service);

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    // Pin the exact error kinds that `HeaderConfigService` relies on to decide
    // whether a failure is tolerated in optional mode.
    #[test]
    fn test_extract_header_config_happy_path() {
        let request = Request::builder()
            .header("x-proxy-config", "s=E%26G&n=1&b=true")
            .body(())
            .unwrap();

        let cfg = extract_header_config::<_, Config, _>(
            &request,
            &HeaderName::from_static("x-proxy-config"),
        )
        .unwrap();
        assert_eq!(cfg.s, "E&G");
        assert_eq!(cfg.n, 1);
        assert!(cfg.m.is_none());
        assert!(cfg.b);
    }

    #[test]
    fn test_extract_header_config_missing_header() {
        let request = Request::builder().body(()).unwrap();

        let result = extract_header_config::<_, Config, _>(
            &request,
            &HeaderName::from_static("x-proxy-config"),
        );
        assert!(matches!(result, Err(HeaderValueErr::HeaderMissing(_))));
    }

    #[test]
    fn test_extract_header_config_invalid_value() {
        let request = Request::builder()
            .header("x-proxy-config", "s=bar&n=1&b=invalid")
            .body(())
            .unwrap();

        let result = extract_header_config::<_, Config, _>(
            &request,
            &HeaderName::from_static("x-proxy-config"),
        );
        assert!(matches!(result, Err(HeaderValueErr::HeaderInvalid(_))));
    }

    #[derive(Debug, Deserialize, Clone)]
    struct Config {
        s: String,
        n: i32,
        m: Option<i32>,
        b: bool,
    }
}