rama_http/layer/
header_option_value.rs

//! Similar to [`super::header_config::HeaderConfigLayer`],
//! but storing the [`Default`] value of type `T` in case
//! the header with the given [`HeaderName`] is present
//! and has a truthy bool-like value (`1` or `true`, case-insensitive).
5
6use crate::{utils::HeaderValueGetter, HeaderName, Request};
7use rama_core::{
8    error::{BoxError, ErrorExt, OpaqueError},
9    Context, Layer, Service,
10};
11use rama_utils::macros::define_inner_service_accessors;
12use std::{fmt, marker::PhantomData};
13
/// A [`Service`] which stores the [`Default`] value of type `T` in case
/// the header with the given [`HeaderName`] is present
/// and has a truthy bool-like value (`1` or `true`, case-insensitive).
///
/// A falsy value (`0` or `false`, case-insensitive) leaves the context
/// untouched; any other value is rejected with an error.
pub struct HeaderOptionValueService<T, S> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Name of the header that is inspected on each request.
    header_name: HeaderName,
    /// When `true`, a missing header is tolerated instead of being an error.
    optional: bool,
    /// Ties the service to `T` without storing a value of that type.
    _marker: PhantomData<fn() -> T>,
}
23
24impl<T, S> HeaderOptionValueService<T, S> {
25    /// Create a new [`HeaderOptionValueService`].
26    ///
27    /// Alias for [`HeaderOptionValueService::required`] if `!optional`
28    /// and [`HeaderOptionValueService::optional`] if `optional`.
29    pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
30        Self {
31            inner,
32            header_name,
33            optional,
34            _marker: PhantomData,
35        }
36    }
37
38    define_inner_service_accessors!();
39
40    /// Create a new [`HeaderOptionValueService`] with the given inner service
41    /// and header name, on which optionally create the value,
42    /// and which will fail if the header is missing.
43    pub const fn required(inner: S, header_name: HeaderName) -> Self {
44        Self::new(inner, header_name, false)
45    }
46
47    /// Create a new [`HeaderOptionValueService`] with the given inner service
48    /// and header name, on which optionally create the value,
49    /// and which will gracefully accept if the header is missing.
50    pub const fn optional(inner: S, header_name: HeaderName) -> Self {
51        Self::new(inner, header_name, true)
52    }
53}
54
55impl<T, S: fmt::Debug> fmt::Debug for HeaderOptionValueService<T, S> {
56    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57        f.debug_struct("HeaderOptionValueService")
58            .field("inner", &self.inner)
59            .field("header_name", &self.header_name)
60            .field("optional", &self.optional)
61            .field(
62                "_marker",
63                &format_args!("{}", std::any::type_name::<fn() -> T>()),
64            )
65            .finish()
66    }
67}
68
69impl<T, S> Clone for HeaderOptionValueService<T, S>
70where
71    S: Clone,
72{
73    fn clone(&self) -> Self {
74        Self {
75            inner: self.inner.clone(),
76            header_name: self.header_name.clone(),
77            optional: self.optional,
78            _marker: PhantomData,
79        }
80    }
81}
82
83impl<T, S, State, Body, E> Service<State, Request<Body>> for HeaderOptionValueService<T, S>
84where
85    S: Service<State, Request<Body>, Error = E>,
86    T: Default + Clone + Send + Sync + 'static,
87    State: Clone + Send + Sync + 'static,
88    Body: Send + Sync + 'static,
89    E: Into<BoxError> + Send + Sync + 'static,
90{
91    type Response = S::Response;
92    type Error = BoxError;
93
94    async fn serve(
95        &self,
96        mut ctx: Context<State>,
97        request: Request<Body>,
98    ) -> Result<Self::Response, Self::Error> {
99        match request.header_str(&self.header_name) {
100            Ok(str_value) => {
101                let str_value = str_value.trim();
102                if str_value == "1" || str_value.eq_ignore_ascii_case("true") {
103                    ctx.insert(T::default());
104                } else if str_value != "0" && !str_value.eq_ignore_ascii_case("false") {
105                    return Err(OpaqueError::from_display(format!(
106                        "invalid '{}' header option: '{}'",
107                        self.header_name, str_value
108                    ))
109                    .into_boxed());
110                }
111            }
112            Err(err) => {
113                if self.optional && matches!(err, crate::utils::HeaderValueErr::HeaderMissing(_)) {
114                    tracing::debug!(
115                        error = %err,
116                        header_name = %self.header_name,
117                        "failed to determine header option",
118                    );
119                    return self.inner.serve(ctx, request).await.map_err(Into::into);
120                } else {
121                    return Err(err
122                        .with_context(|| format!("determine '{}' header option", self.header_name))
123                        .into_boxed());
124                }
125            }
126        };
127        self.inner.serve(ctx, request).await.map_err(Into::into)
128    }
129}
130
/// Layer which stores the [`Default`] value of type `T` in case
/// the header with the given [`HeaderName`] is present
/// and has a truthy bool-like value (`1` or `true`, case-insensitive).
///
/// Wraps inner services in a [`HeaderOptionValueService`].
pub struct HeaderOptionValueLayer<T> {
    /// Name of the header that the produced services inspect.
    header_name: HeaderName,
    /// When `true`, the produced services tolerate a missing header.
    optional: bool,
    /// Ties the layer to `T` without storing a value of that type.
    _marker: PhantomData<fn() -> T>,
}
139
140impl<T> fmt::Debug for HeaderOptionValueLayer<T> {
141    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142        f.debug_struct("HeaderOptionValueLayer")
143            .field("header_name", &self.header_name)
144            .field("optional", &self.optional)
145            .field(
146                "_marker",
147                &format_args!("{}", std::any::type_name::<fn() -> T>()),
148            )
149            .finish()
150    }
151}
152
153impl<T> Clone for HeaderOptionValueLayer<T> {
154    fn clone(&self) -> Self {
155        Self {
156            header_name: self.header_name.clone(),
157            optional: self.optional,
158            _marker: PhantomData,
159        }
160    }
161}
162
163impl<T> HeaderOptionValueLayer<T> {
164    /// Create a new [`HeaderOptionValueLayer`] with the given header name,
165    /// on which optionally create the valu,
166    /// and which will fail if the header is missing.
167    pub fn required(header_name: HeaderName) -> Self {
168        Self {
169            header_name,
170            optional: false,
171            _marker: PhantomData,
172        }
173    }
174
175    /// Create a new [`HeaderOptionValueLayer`] with the given header name,
176    /// on which optionally create the valu,
177    /// and which will gracefully accept if the header is missing.
178    pub fn optional(header_name: HeaderName) -> Self {
179        Self {
180            header_name,
181            optional: true,
182            _marker: PhantomData,
183        }
184    }
185}
186
187impl<T, S> Layer<S> for HeaderOptionValueLayer<T> {
188    type Service = HeaderOptionValueService<T, S>;
189
190    fn layer(&self, inner: S) -> Self::Service {
191        HeaderOptionValueService::new(inner, self.header_name.clone(), self.optional)
192    }
193}
194
#[cfg(test)]
mod test {
    use super::*;
    use crate::Method;

    #[derive(Debug, Clone, Default)]
    struct UnitValue;

    /// Name of the header inspected by every service under test.
    const HEADER: &str = "x-unit-value";

    /// Accepted header values, paired with whether [`UnitValue`]
    /// is expected to be found in the context afterwards.
    const BOOL_LIKE_CASES: [(&str, bool); 10] = [
        ("1", true),
        ("true", true),
        ("True", true),
        ("TrUE", true),
        ("TRUE", true),
        ("0", false),
        ("false", false),
        ("False", false),
        ("FaLsE", false),
        ("FALSE", false),
    ];

    /// Header values that are neither truthy nor falsy.
    const INVALID_CASES: [&str; 3] = ["", "foo", "yes"];

    /// Build a GET request which carries the inspected header
    /// only when a value is given.
    fn build_request(header_value: Option<&str>) -> Request<()> {
        let builder = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com");
        let builder = match header_value {
            Some(value) => builder.header(HEADER, value),
            None => builder,
        };
        builder.body(()).unwrap()
    }

    #[tokio::test]
    async fn test_header_option_value_required_happy_path() {
        for (header_value, should_contain) in BOOL_LIKE_CASES {
            let request = build_request(Some(header_value));

            let inner_service = rama_core::service::service_fn(
                move |ctx: Context<()>, _req: Request<()>| async move {
                    assert_eq!(should_contain, ctx.contains::<UnitValue>());
                    Ok::<_, std::convert::Infallible>(())
                },
            );

            let service = HeaderOptionValueService::<UnitValue, _>::required(
                inner_service,
                HeaderName::from_static(HEADER),
            );

            service.serve(Context::default(), request).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_found() {
        for (header_value, should_contain) in BOOL_LIKE_CASES {
            let request = build_request(Some(header_value));

            let inner_service = rama_core::service::service_fn(
                move |ctx: Context<()>, _req: Request<()>| async move {
                    assert_eq!(should_contain, ctx.contains::<UnitValue>());
                    Ok::<_, std::convert::Infallible>(())
                },
            );

            let service = HeaderOptionValueService::<UnitValue, _>::optional(
                inner_service,
                HeaderName::from_static(HEADER),
            );

            service.serve(Context::default(), request).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_missing() {
        let request = build_request(None);

        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                assert!(!ctx.contains::<UnitValue>());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderOptionValueService::<UnitValue, _>::optional(
            inner_service,
            HeaderName::from_static(HEADER),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_option_value_required_missing_header() {
        let request = build_request(None);

        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderOptionValueService::<UnitValue, _>::required(
            inner_service,
            HeaderName::from_static(HEADER),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_option_value_required_invalid_value() {
        for invalid_value in INVALID_CASES {
            let request = build_request(Some(invalid_value));

            let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
                Ok::<_, std::convert::Infallible>(())
            });

            let service = HeaderOptionValueService::<UnitValue, _>::required(
                inner_service,
                HeaderName::from_static(HEADER),
            );

            let result = service.serve(Context::default(), request).await;
            assert!(result.is_err());
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_invalid_value() {
        for invalid_value in INVALID_CASES {
            let request = build_request(Some(invalid_value));

            let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
                Ok::<_, std::convert::Infallible>(())
            });

            let service = HeaderOptionValueService::<UnitValue, _>::optional(
                inner_service,
                HeaderName::from_static(HEADER),
            );

            let result = service.serve(Context::default(), request).await;
            assert!(result.is_err());
        }
    }
}
372        }
373    }
374}