rama_http/layer/
header_config.rs

//! Extract a form-urlencoded config from a request header and insert it into the [`Extensions`] of its [`Context`].
2//!
3//! [`Extensions`]: rama_core::context::Extensions
4//! [`Context`]: rama_core::Context
5//!
6//! # Example
7//!
8//! ```rust
9//! use rama_http::layer::header_config::{HeaderConfigLayer, HeaderConfigService};
10//! use rama_http::service::web::{WebService};
11//! use rama_http::{Body, Request, StatusCode, HeaderName};
12//! use rama_core::{Context, Service, Layer};
13//! use serde::Deserialize;
14//!
15//! #[derive(Debug, Deserialize, Clone)]
16//! struct Config {
17//!     s: String,
18//!     n: i32,
19//!     m: Option<i32>,
20//!     b: bool,
21//! }
22//!
23//! #[tokio::main]
24//! async fn main() {
25//!     let service = HeaderConfigLayer::<Config>::required(HeaderName::from_static("x-proxy-config"))
26//!         .into_layer(WebService::default()
27//!             .get("/", async |ctx: Context<()>| {
28//!                 let cfg = ctx.get::<Config>().unwrap();
29//!                 assert_eq!(cfg.s, "E&G");
30//!                 assert_eq!(cfg.n, 1);
31//!                 assert!(cfg.m.is_none());
32//!                 assert!(cfg.b);
33//!             }),
34//!         );
35//!
36//!     let request = Request::builder()
37//!         .header("x-proxy-config", "s=E%26G&n=1&b=true")
38//!         .body(Body::empty())
39//!         .unwrap();
40//!
41//!     let resp = service.serve(Context::default(), request).await.unwrap();
42//!     assert_eq!(resp.status(), StatusCode::OK);
43//! }
44//! ```
45
46use crate::{HeaderName, header::AsHeaderName};
47use crate::{
48    Request,
49    utils::{HeaderValueErr, HeaderValueGetter},
50};
51use rama_core::{Context, Layer, Service, error::BoxError};
52use rama_utils::macros::define_inner_service_accessors;
53use serde::de::DeserializeOwned;
54use std::{fmt, marker::PhantomData};
55
56/// Extract a header config from a request or response without consuming it.
57pub fn extract_header_config<H, T, G>(request: &G, header_name: H) -> Result<T, HeaderValueErr>
58where
59    H: AsHeaderName + Copy,
60    T: DeserializeOwned + Clone + Send + Sync + 'static,
61    G: HeaderValueGetter,
62{
63    let value = request.header_str(header_name)?;
64    let config = serde_html_form::from_str::<T>(value)
65        .map_err(|_| HeaderValueErr::HeaderInvalid(header_name.as_str().to_owned()))?;
66    Ok(config)
67}
68
69/// A [`Service`] which extracts a header config from a request or response
70/// and inserts it into the [`Extensions`] of that object.
71///
72/// [`Extensions`]: rama_core::context::Extensions
73pub struct HeaderConfigService<T, S> {
74    inner: S,
75    header_name: HeaderName,
76    optional: bool,
77    _marker: PhantomData<fn() -> T>,
78}
79
80impl<T, S> HeaderConfigService<T, S> {
81    /// Create a new [`HeaderConfigService`].
82    ///
83    /// Alias for [`HeaderConfigService::required`] if `!optional`
84    /// and [`HeaderConfigService::optional`] if `optional`.
85    pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
86        Self {
87            inner,
88            header_name,
89            optional,
90            _marker: PhantomData,
91        }
92    }
93
94    define_inner_service_accessors!();
95
96    /// Create a new [`HeaderConfigService`] with the given inner service
97    /// and header name, on which to extract the config,
98    /// and which will fail if the header is missing.
99    pub const fn required(inner: S, header_name: HeaderName) -> Self {
100        Self::new(inner, header_name, false)
101    }
102
103    /// Create a new [`HeaderConfigService`] with the given inner service
104    /// and header name, on which to extract the config,
105    /// and which will gracefully accept if the header is missing.
106    pub const fn optional(inner: S, header_name: HeaderName) -> Self {
107        Self::new(inner, header_name, true)
108    }
109}
110
111impl<T, S: fmt::Debug> fmt::Debug for HeaderConfigService<T, S> {
112    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113        f.debug_struct("HeaderConfigService")
114            .field("inner", &self.inner)
115            .field("header_name", &self.header_name)
116            .field("optional", &self.optional)
117            .field(
118                "_marker",
119                &format_args!("{}", std::any::type_name::<fn() -> T>()),
120            )
121            .finish()
122    }
123}
124
125impl<T, S> Clone for HeaderConfigService<T, S>
126where
127    S: Clone,
128{
129    fn clone(&self) -> Self {
130        Self {
131            inner: self.inner.clone(),
132            header_name: self.header_name.clone(),
133            optional: self.optional,
134            _marker: PhantomData,
135        }
136    }
137}
138
139impl<T, S, State, Body, E> Service<State, Request<Body>> for HeaderConfigService<T, S>
140where
141    S: Service<State, Request<Body>, Error = E>,
142    T: DeserializeOwned + Clone + Send + Sync + 'static,
143    State: Clone + Send + Sync + 'static,
144    Body: Send + Sync + 'static,
145    E: Into<BoxError> + Send + Sync + 'static,
146{
147    type Response = S::Response;
148    type Error = BoxError;
149
150    async fn serve(
151        &self,
152        mut ctx: Context<State>,
153        request: Request<Body>,
154    ) -> Result<Self::Response, Self::Error> {
155        let config = match extract_header_config::<_, T, _>(&request, &self.header_name) {
156            Ok(config) => config,
157            Err(err) => {
158                if self.optional && matches!(err, crate::utils::HeaderValueErr::HeaderMissing(_)) {
159                    tracing::debug!(error = %err, "failed to extract header config");
160                    return self.inner.serve(ctx, request).await.map_err(Into::into);
161                } else {
162                    return Err(err.into());
163                }
164            }
165        };
166        ctx.insert(config);
167        self.inner.serve(ctx, request).await.map_err(Into::into)
168    }
169}
170
171/// Layer which extracts a header config for the given HeaderName
172/// from a request or response and inserts it into the [`Extensions`] of that object.
173///
174/// [`Extensions`]: rama_core::context::Extensions
175pub struct HeaderConfigLayer<T> {
176    header_name: HeaderName,
177    optional: bool,
178    _marker: PhantomData<fn() -> T>,
179}
180
181impl<T> fmt::Debug for HeaderConfigLayer<T> {
182    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183        f.debug_struct("HeaderConfigLayer")
184            .field("header_name", &self.header_name)
185            .field("optional", &self.optional)
186            .field(
187                "_marker",
188                &format_args!("{}", std::any::type_name::<fn() -> T>()),
189            )
190            .finish()
191    }
192}
193
194impl<T> Clone for HeaderConfigLayer<T> {
195    fn clone(&self) -> Self {
196        Self {
197            header_name: self.header_name.clone(),
198            optional: self.optional,
199            _marker: PhantomData,
200        }
201    }
202}
203
204impl<T> HeaderConfigLayer<T> {
205    /// Create a new [`HeaderConfigLayer`] with the given header name,
206    /// on which to extract the config,
207    /// and which will fail if the header is missing.
208    pub fn required(header_name: HeaderName) -> Self {
209        Self {
210            header_name,
211            optional: false,
212            _marker: PhantomData,
213        }
214    }
215
216    /// Create a new [`HeaderConfigLayer`] with the given header name,
217    /// on which to extract the config,
218    /// and which will gracefully accept if the header is missing.
219    pub fn optional(header_name: HeaderName) -> Self {
220        Self {
221            header_name,
222            optional: true,
223            _marker: PhantomData,
224        }
225    }
226}
227
228impl<T, S> Layer<S> for HeaderConfigLayer<T> {
229    type Service = HeaderConfigService<T, S>;
230
231    fn layer(&self, inner: S) -> Self::Service {
232        HeaderConfigService::new(inner, self.header_name.clone(), self.optional)
233    }
234
235    fn into_layer(self, inner: S) -> Self::Service {
236        HeaderConfigService::new(inner, self.header_name, self.optional)
237    }
238}
239
#[cfg(test)]
mod test {
    use serde::Deserialize;

    use crate::Method;

    use super::*;

    /// Header that carries the config in every test request.
    const CONFIG_HEADER: &str = "x-proxy-config";
    /// A config value that deserializes into the expected [`Config`].
    const VALID_CONFIG: &str = "s=E%26G&n=1&b=true";
    /// A config value whose `b` field is not a valid bool.
    const INVALID_CONFIG: &str = "s=bar&n=1&b=invalid";

    fn config_header_name() -> HeaderName {
        HeaderName::from_static(CONFIG_HEADER)
    }

    /// Build a GET request that includes the config header only when `config` is `Some`.
    fn build_request(config: Option<&str>) -> Request<()> {
        let mut builder = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com");
        if let Some(value) = config {
            builder = builder.header(CONFIG_HEADER, value);
        }
        builder.body(()).unwrap()
    }

    /// Assert that `ctx` holds the config decoded from [`VALID_CONFIG`].
    fn assert_expected_config(ctx: &Context<()>) {
        let cfg: &Config = ctx.get().unwrap();
        assert_eq!(cfg.s, "E&G");
        assert_eq!(cfg.n, 1);
        assert!(cfg.m.is_none());
        assert!(cfg.b);
    }

    #[tokio::test]
    async fn test_header_config_required_happy_path() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert_expected_config(&ctx);
                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderConfigService::<Config, _>::required(inner_service, config_header_name());

        service
            .serve(Context::default(), build_request(Some(VALID_CONFIG)))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_found() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert_expected_config(&ctx);
                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderConfigService::<Config, _>::optional(inner_service, config_header_name());

        service
            .serve(Context::default(), build_request(Some(VALID_CONFIG)))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_missing() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Config>().is_none());
                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderConfigService::<Config, _>::optional(inner_service, config_header_name());

        service
            .serve(Context::default(), build_request(None))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_missing_header() {
        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service =
            HeaderConfigService::<Config, _>::required(inner_service, config_header_name());

        let outcome = service
            .serve(Context::default(), build_request(None))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_header_config_required_invalid_config() {
        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service =
            HeaderConfigService::<Config, _>::required(inner_service, config_header_name());

        let outcome = service
            .serve(Context::default(), build_request(Some(INVALID_CONFIG)))
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_header_config_optional_invalid_config() {
        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service =
            HeaderConfigService::<Config, _>::optional(inner_service, config_header_name());

        // Optional only tolerates a *missing* header; a malformed one still fails.
        let outcome = service
            .serve(Context::default(), build_request(Some(INVALID_CONFIG)))
            .await;
        assert!(outcome.is_err());
    }

    #[derive(Debug, Deserialize, Clone)]
    struct Config {
        s: String,
        n: i32,
        m: Option<i32>,
        b: bool,
    }
}