rama_http/layer/
header_from_str_config.rs

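//! Extract a header config from a request and insert it into the [`Extensions`] of its [`Context`].
//!
//! The header value is parsed using [`FromStr`]. When `repeat` is enabled, the values of
//! all headers with the configured name are split on commas, and the parsed items are
//! collected into a container (`Vec<T>` by default).
//!
//! [`Extensions`]: rama_core::context::Extensions
//! [`Context`]: rama_core::Context
//! [`FromStr`]: std::str::FromStr
//!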
6//! # Example
7//!
8//! ```rust
9//! use rama_http::layer::header_from_str_config::HeaderFromStrConfigLayer;
10//! use rama_http::service::web::{WebService};
11//! use rama_http::{Body, Request, StatusCode, HeaderName};
12//! use rama_core::{Context, Service, Layer};
13//! use serde::Deserialize;
14//!
15//! #[tokio::main]
16//! async fn main() {
17//!     let service = HeaderFromStrConfigLayer::<String>::required(HeaderName::from_static("x-proxy-labels"))
18//!         .with_repeat(true)
19//!         .into_layer(WebService::default()
20//!             .get("/", async |ctx: Context<()>| {
21//!                 // For production-like code you should prefer a custom type
22//!                 // to avoid possible conflicts. Ideally these are also as
23//!                 // cheap as possible to allocate.
24//!                 let labels: &Vec<String> = ctx.get().unwrap();
25//!                 assert_eq!("a+b+c", labels.join("+"));
26//!             }),
27//!         );
28//!
29//!     let request = Request::builder()
30//!         .header("x-proxy-labels", "a, b")
31//!         .header("x-proxy-labels", "c")
32//!         .body(Body::empty())
33//!         .unwrap();
34//!
35//!     let resp = service.serve(Context::default(), request).await.unwrap();
36//!     assert_eq!(resp.status(), StatusCode::OK);
37//! }
38//! ```
39
40use crate::HeaderName;
41use crate::{
42    Request,
43    utils::{HeaderValueErr, HeaderValueGetter},
44};
45use rama_core::{Context, Layer, Service, error::BoxError};
46use rama_utils::macros::define_inner_service_accessors;
47use std::iter::FromIterator;
48use std::str::FromStr;
49use std::{fmt, marker::PhantomData};
50
/// A [`Service`] which extracts a config from a request header, parses it
/// using [`FromStr`], and inserts it into the [`Extensions`] of the [`Context`]
/// that is passed on to the inner service.
///
/// Without `repeat`, the header value is parsed as a single `T`. With `repeat`
/// enabled, the values of all headers with the configured name are split on `,`
/// and the parsed items are collected into a container `C` (`Vec<T>` by default).
///
/// [`Extensions`]: rama_core::context::Extensions
pub struct HeaderFromStrConfigService<T, S, C = Vec<T>> {
    /// The wrapped service, called once the header has been processed.
    inner: S,
    /// Name of the header to read the config from.
    header_name: HeaderName,
    /// When `true`, a missing header is accepted instead of being an error.
    optional: bool,
    /// When `true`, parse a comma-separated (and possibly repeated) header into `C`.
    repeat: bool,
    /// Ties `T` and `C` to the type without storing a value of either;
    /// a `fn() -> ..` marker is `Send + Sync` regardless of `T` and `C`.
    _marker: PhantomData<fn() -> (T, C)>,
}
62
63impl<T, S, C> HeaderFromStrConfigService<T, S, C> {
64    define_inner_service_accessors!();
65
66    /// Create a new [`HeaderFromStrConfigService`] with the given inner service
67    /// and header name, on which to extract the config,
68    /// and which will fail if the header is missing.
69    pub const fn required(inner: S, header_name: HeaderName) -> Self {
70        Self {
71            inner,
72            header_name,
73            optional: false,
74            repeat: false,
75            _marker: PhantomData,
76        }
77    }
78
79    /// Create a new [`HeaderFromStrConfigService`] with the given inner service
80    /// and header name, on which to extract the config,
81    /// and which will gracefully accept if the header is missing.
82    pub const fn optional(inner: S, header_name: HeaderName) -> Self {
83        Self {
84            inner,
85            header_name,
86            optional: true,
87            repeat: false,
88            _marker: PhantomData,
89        }
90    }
91
92    /// Toggle repeat on/off. When repeat is enabled the
93    /// data config will be parsed and inserted as a container of type `C` (defaults to `Vec<T>`).
94    pub fn set_repeat(&mut self, repeat: bool) -> &mut Self {
95        self.repeat = repeat;
96        self
97    }
98
99    /// Toggle repeat on/off. When repeat is enabled the
100    /// data config will be parsed and inserted as a container of type `C` (defaults to `Vec<T>`).
101    pub fn with_repeat(mut self, repeat: bool) -> Self {
102        self.repeat = repeat;
103        self
104    }
105}
106
107impl<T, S, C> fmt::Debug for HeaderFromStrConfigService<T, S, C>
108where
109    S: fmt::Debug,
110{
111    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112        f.debug_struct("HeaderFromStrConfigService")
113            .field("inner", &self.inner)
114            .field("header_name", &self.header_name)
115            .field("optional", &self.optional)
116            .field("repeat", &self.repeat)
117            .field(
118                "_marker",
119                &format_args!("{}", std::any::type_name::<fn() -> (T, C)>()),
120            )
121            .finish()
122    }
123}
124
125impl<T, S, C> Clone for HeaderFromStrConfigService<T, S, C>
126where
127    S: Clone,
128{
129    fn clone(&self) -> Self {
130        Self {
131            inner: self.inner.clone(),
132            header_name: self.header_name.clone(),
133            optional: self.optional,
134            repeat: self.repeat,
135            _marker: PhantomData,
136        }
137    }
138}
139
140impl<T, S, State, Body, E, C> Service<State, Request<Body>> for HeaderFromStrConfigService<T, S, C>
141where
142    S: Service<State, Request<Body>, Error = E>,
143    T: FromStr<Err: Into<BoxError> + Send + Sync + 'static> + Send + Sync + 'static + Clone,
144    C: FromIterator<T> + Send + Sync + 'static + Clone,
145    State: Clone + Send + Sync + 'static,
146    Body: Send + Sync + 'static,
147    E: Into<BoxError> + Send + Sync + 'static,
148{
149    type Response = S::Response;
150    type Error = BoxError;
151
152    async fn serve(
153        &self,
154        mut ctx: Context<State>,
155        request: Request<Body>,
156    ) -> Result<Self::Response, Self::Error> {
157        if self.repeat {
158            let headers = request.headers().get_all(&self.header_name);
159            let mut parsed_values = headers
160                .into_iter()
161                .flat_map(|value| {
162                    value.to_str().into_iter().flat_map(|string| {
163                        string
164                            .split(',')
165                            .filter_map(|x| match x.trim() {
166                                "" => None,
167                                y => Some(y),
168                            })
169                            .map(|x| x.parse::<T>().map_err(Into::into))
170                    })
171                })
172                .peekable();
173
174            if parsed_values.peek().is_none() {
175                if !self.optional {
176                    return Err(HeaderValueErr::HeaderMissing(self.header_name.to_string()).into());
177                }
178            } else {
179                let values = parsed_values.collect::<Result<C, _>>()?;
180                ctx.insert(values);
181            }
182        } else {
183            match request.header_str(&self.header_name) {
184                Ok(s) => {
185                    let cfg: T = s.parse().map_err(Into::into)?;
186                    ctx.insert(cfg);
187                }
188                Err(HeaderValueErr::HeaderMissing(_)) if self.optional => (),
189                Err(err) => {
190                    return Err(err.into());
191                }
192            }
193        }
194
195        self.inner.serve(ctx, request).await.map_err(Into::into)
196    }
197}
198
/// [`Layer`] which wraps a service in a [`HeaderFromStrConfigService`]:
/// the config is extracted from the request header with the given [`HeaderName`],
/// parsed, and inserted into the [`Extensions`] of the request's [`Context`].
///
/// `T` is the type a single item is parsed into; `C` is the container used
/// when `repeat` is enabled (`Vec<T>` by default).
///
/// [`Extensions`]: rama_core::context::Extensions
pub struct HeaderFromStrConfigLayer<T, C = Vec<T>> {
    /// Name of the header to read the config from.
    header_name: HeaderName,
    /// When `true`, a missing header is accepted instead of being an error.
    optional: bool,
    /// When `true`, parse a comma-separated (and possibly repeated) header into `C`.
    repeat: bool,
    /// Ties `T` and `C` to the type without storing a value of either;
    /// a `fn() -> ..` marker is `Send + Sync` regardless of `T` and `C`.
    _marker: PhantomData<fn() -> (T, C)>,
}
209
210impl<T, C: fmt::Debug> fmt::Debug for HeaderFromStrConfigLayer<T, C> {
211    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
212        f.debug_struct("HeaderFromStrConfigLayer")
213            .field("header_name", &self.header_name)
214            .field("optional", &self.optional)
215            .field("repeat", &self.repeat)
216            .field(
217                "_marker",
218                &format_args!("{}", std::any::type_name::<fn() -> (T, C)>()),
219            )
220            .finish()
221    }
222}
223
224impl<T, C> Clone for HeaderFromStrConfigLayer<T, C> {
225    fn clone(&self) -> Self {
226        Self {
227            header_name: self.header_name.clone(),
228            optional: self.optional,
229            repeat: self.repeat,
230            _marker: PhantomData,
231        }
232    }
233}
234
235impl<T, C> HeaderFromStrConfigLayer<T, C> {
236    /// Create a new [`HeaderFromStrConfigLayer`] with the given header name,
237    /// on which to extract the config,
238    /// and which will fail if the header is missing.
239    pub fn required(header_name: HeaderName) -> Self {
240        Self {
241            header_name,
242            optional: false,
243            repeat: false,
244            _marker: PhantomData,
245        }
246    }
247
248    /// Create a new [`HeaderFromStrConfigLayer`] with the given header name,
249    /// on which to extract the config,
250    /// and which will gracefully accept if the header is missing.
251    pub fn optional(header_name: HeaderName) -> Self {
252        Self {
253            header_name,
254            optional: true,
255            repeat: false,
256            _marker: PhantomData,
257        }
258    }
259
260    /// Toggle repeat on/off. When repeat is enabled the
261    /// data config will be parsed and inserted as a container of type `C` (defaults to `Vec<T>`).
262    pub fn set_repeat(&mut self, repeat: bool) -> &mut Self {
263        self.repeat = repeat;
264        self
265    }
266
267    /// Toggle repeat on/off. When repeat is enabled the
268    /// data config will be parsed and inserted as a container of type `C` (defaults to `Vec<T>`).
269    pub fn with_repeat(mut self, repeat: bool) -> Self {
270        self.repeat = repeat;
271        self
272    }
273}
274
275impl<T, S, C> Layer<S> for HeaderFromStrConfigLayer<T, C> {
276    type Service = HeaderFromStrConfigService<T, S, C>;
277
278    fn layer(&self, inner: S) -> Self::Service {
279        HeaderFromStrConfigService {
280            inner,
281            header_name: self.header_name.clone(),
282            optional: self.optional,
283            repeat: self.repeat,
284            _marker: PhantomData,
285        }
286    }
287
288    fn into_layer(self, inner: S) -> Self::Service {
289        HeaderFromStrConfigService {
290            inner,
291            header_name: self.header_name,
292            optional: self.optional,
293            repeat: self.repeat,
294            _marker: PhantomData,
295        }
296    }
297}
298
#[cfg(test)]
mod test {
    use super::*;
    use crate::Method;
    use std::collections::{HashSet, LinkedList};
    use std::convert::Infallible;

    /// Build a `GET https://www.example.com` request with a `()` body and the
    /// given `(name, value)` headers appended in order.
    fn get_request(headers: &[(&str, &str)]) -> Request<()> {
        headers
            .iter()
            .fold(
                Request::builder()
                    .method(Method::GET)
                    .uri("https://www.example.com"),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(())
            .unwrap()
    }

    #[tokio::test]
    async fn test_header_config_required_happy_path() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let id: &usize = ctx.get().unwrap();
                assert_eq!(*id, 42);

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        let request = get_request(&[("x-proxy-id", "42")]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_repeat_happy_path() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &Vec<String> = ctx.get().unwrap();
                assert_eq!("foo+bar+baz+fin", labels.join("+"));

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let request = get_request(&[("x-proxy-labels", "foo,bar ,baz, fin ")]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_repeat_custom_container() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &HashSet<String> = ctx.get().unwrap();
                // The duplicate "foo" collapses inside the set.
                assert_eq!(3, labels.len());
                for expected in ["foo", "bar", "baz"] {
                    assert!(labels.contains(expected));
                }

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _, HashSet<String>>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let request = get_request(&[("x-proxy-labels", "foo,bar,baz,foo")]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_repeat_linked_list() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &LinkedList<String> = ctx.get().unwrap();
                // Order of the items must match the order in the header.
                let in_order: Vec<&str> = labels.iter().map(String::as_str).collect();
                assert_eq!(vec!["foo", "bar", "baz"], in_order);

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _, LinkedList<String>>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let request = get_request(&[("x-proxy-labels", "foo,bar,baz")]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_repeat_happy_path_multi_header() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &Vec<String> = ctx.get().unwrap();
                assert_eq!("foo+bar+baz+fin", labels.join("+"));

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        // Same header, repeated with differing name casing.
        let request = get_request(&[
            ("x-proxy-labels", "foo,bar "),
            ("x-Proxy-Labels", "baz "),
            ("X-PROXY-LABELS", " fin"),
        ]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_found() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let id: usize = *ctx.get().unwrap();
                assert_eq!(id, 42);

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        let request = get_request(&[("x-proxy-id", "42")]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_repeat_optional_found() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &Vec<String> = ctx.get().unwrap();
                assert_eq!("foo+bar+baz+fin", labels.join("+"));

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let request = get_request(&[("x-proxy-labels", "foo,bar ,baz, fin ")]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_missing() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<usize>().is_none());
                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        let request = get_request(&[]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_repeat_optional_missing() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Vec<String>>().is_none());

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let request = get_request(&[]);
        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_missing_header() {
        let inner_service =
            rama_core::service::service_fn(async |_ctx: Context<()>, _req: Request<()>| {
                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        let request = get_request(&[]);
        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_repeat_required_missing() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Vec<String>>().is_none());

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let request = get_request(&[]);
        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_required_invalid_config() {
        let inner_service =
            rama_core::service::service_fn(async |_ctx: Context<()>, _req: Request<()>| {
                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        // "foo" does not parse as `usize`.
        let request = get_request(&[("x-proxy-id", "foo")]);
        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_repeat_required_invalid_config() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Vec<String>>().is_none());

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        // The second item, "foo", does not parse as `usize`.
        let request = get_request(&[("x-proxy-labels", "42,foo")]);
        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_optional_invalid_config() {
        let inner_service =
            rama_core::service::service_fn(async |_ctx: Context<()>, _req: Request<()>| {
                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        // Optional only tolerates a missing header, not an unparsable one.
        let request = get_request(&[("x-proxy-id", "foo")]);
        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_repeat_optional_invalid_config() {
        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Vec<String>>().is_none());

                Ok::<_, Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        // Optional only tolerates a missing header, not an unparsable item.
        let request = get_request(&[("x-proxy-labels", "42,foo")]);
        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }
}