rama_http/layer/
header_option_value.rs

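//! Similar to [`super::header_config::HeaderConfigLayer`],
//! but storing the [`Default`] value of type `T` in the request context
//! when the header with the given [`HeaderName`] is present
//! and holds a truthy bool-like value (`1` or `true`, case-insensitive).
//! The values `0` and `false` are accepted but store nothing;
//! any other value is rejected with an error.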
5
6use crate::{HeaderName, Request, utils::HeaderValueGetter};
7use rama_core::{
8    Context, Layer, Service,
9    error::{BoxError, ErrorExt, OpaqueError},
10};
11use rama_utils::macros::define_inner_service_accessors;
12use std::{fmt, marker::PhantomData};
13
14/// A [`Service`] which stores the [`Default`] value of type `T` in case
15/// the header with the given [`HeaderName`] is present
16/// and has a bool-like value.
17pub struct HeaderOptionValueService<T, S> {
18    inner: S,
19    header_name: HeaderName,
20    optional: bool,
21    _marker: PhantomData<fn() -> T>,
22}
23
24impl<T, S> HeaderOptionValueService<T, S> {
25    /// Create a new [`HeaderOptionValueService`].
26    ///
27    /// Alias for [`HeaderOptionValueService::required`] if `!optional`
28    /// and [`HeaderOptionValueService::optional`] if `optional`.
29    pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
30        Self {
31            inner,
32            header_name,
33            optional,
34            _marker: PhantomData,
35        }
36    }
37
38    define_inner_service_accessors!();
39
40    /// Create a new [`HeaderOptionValueService`] with the given inner service
41    /// and header name, on which optionally create the value,
42    /// and which will fail if the header is missing.
43    pub const fn required(inner: S, header_name: HeaderName) -> Self {
44        Self::new(inner, header_name, false)
45    }
46
47    /// Create a new [`HeaderOptionValueService`] with the given inner service
48    /// and header name, on which optionally create the value,
49    /// and which will gracefully accept if the header is missing.
50    pub const fn optional(inner: S, header_name: HeaderName) -> Self {
51        Self::new(inner, header_name, true)
52    }
53}
54
55impl<T, S: fmt::Debug> fmt::Debug for HeaderOptionValueService<T, S> {
56    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57        f.debug_struct("HeaderOptionValueService")
58            .field("inner", &self.inner)
59            .field("header_name", &self.header_name)
60            .field("optional", &self.optional)
61            .field(
62                "_marker",
63                &format_args!("{}", std::any::type_name::<fn() -> T>()),
64            )
65            .finish()
66    }
67}
68
69impl<T, S> Clone for HeaderOptionValueService<T, S>
70where
71    S: Clone,
72{
73    fn clone(&self) -> Self {
74        Self {
75            inner: self.inner.clone(),
76            header_name: self.header_name.clone(),
77            optional: self.optional,
78            _marker: PhantomData,
79        }
80    }
81}
82
83impl<T, S, State, Body, E> Service<State, Request<Body>> for HeaderOptionValueService<T, S>
84where
85    S: Service<State, Request<Body>, Error = E>,
86    T: Default + Clone + Send + Sync + 'static,
87    State: Clone + Send + Sync + 'static,
88    Body: Send + Sync + 'static,
89    E: Into<BoxError> + Send + Sync + 'static,
90{
91    type Response = S::Response;
92    type Error = BoxError;
93
94    async fn serve(
95        &self,
96        mut ctx: Context<State>,
97        request: Request<Body>,
98    ) -> Result<Self::Response, Self::Error> {
99        match request.header_str(&self.header_name) {
100            Ok(str_value) => {
101                let str_value = str_value.trim();
102                if str_value == "1" || str_value.eq_ignore_ascii_case("true") {
103                    ctx.insert(T::default());
104                } else if str_value != "0" && !str_value.eq_ignore_ascii_case("false") {
105                    return Err(OpaqueError::from_display(format!(
106                        "invalid '{}' header option: '{}'",
107                        self.header_name, str_value
108                    ))
109                    .into_boxed());
110                }
111            }
112            Err(err) => {
113                if self.optional && matches!(err, crate::utils::HeaderValueErr::HeaderMissing(_)) {
114                    tracing::debug!(
115                        error = %err,
116                        header_name = %self.header_name,
117                        "failed to determine header option",
118                    );
119                    return self.inner.serve(ctx, request).await.map_err(Into::into);
120                } else {
121                    return Err(err
122                        .with_context(|| format!("determine '{}' header option", self.header_name))
123                        .into_boxed());
124                }
125            }
126        };
127        self.inner.serve(ctx, request).await.map_err(Into::into)
128    }
129}
130
131/// Layer which stores the [`Default`] value of type `T` in case
132/// the header with the given [`HeaderName`] is present
133/// and has a bool-like value.
134pub struct HeaderOptionValueLayer<T> {
135    header_name: HeaderName,
136    optional: bool,
137    _marker: PhantomData<fn() -> T>,
138}
139
140impl<T> fmt::Debug for HeaderOptionValueLayer<T> {
141    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142        f.debug_struct("HeaderOptionValueLayer")
143            .field("header_name", &self.header_name)
144            .field("optional", &self.optional)
145            .field(
146                "_marker",
147                &format_args!("{}", std::any::type_name::<fn() -> T>()),
148            )
149            .finish()
150    }
151}
152
153impl<T> Clone for HeaderOptionValueLayer<T> {
154    fn clone(&self) -> Self {
155        Self {
156            header_name: self.header_name.clone(),
157            optional: self.optional,
158            _marker: PhantomData,
159        }
160    }
161}
162
163impl<T> HeaderOptionValueLayer<T> {
164    /// Create a new [`HeaderOptionValueLayer`] with the given header name,
165    /// on which optionally create the valu,
166    /// and which will fail if the header is missing.
167    pub fn required(header_name: HeaderName) -> Self {
168        Self {
169            header_name,
170            optional: false,
171            _marker: PhantomData,
172        }
173    }
174
175    /// Create a new [`HeaderOptionValueLayer`] with the given header name,
176    /// on which optionally create the valu,
177    /// and which will gracefully accept if the header is missing.
178    pub fn optional(header_name: HeaderName) -> Self {
179        Self {
180            header_name,
181            optional: true,
182            _marker: PhantomData,
183        }
184    }
185}
186
187impl<T, S> Layer<S> for HeaderOptionValueLayer<T> {
188    type Service = HeaderOptionValueService<T, S>;
189
190    fn layer(&self, inner: S) -> Self::Service {
191        HeaderOptionValueService::new(inner, self.header_name.clone(), self.optional)
192    }
193
194    fn into_layer(self, inner: S) -> Self::Service {
195        HeaderOptionValueService::new(inner, self.header_name, self.optional)
196    }
197}
198
#[cfg(test)]
mod test {
    use super::*;
    use crate::Method;

    #[derive(Debug, Clone, Default)]
    struct UnitValue;

    /// Accepted header values, paired with whether they should cause a
    /// `UnitValue` to be inserted into the context.
    const VALID_CASES: [(&str, bool); 10] = [
        ("1", true),
        ("true", true),
        ("True", true),
        ("TrUE", true),
        ("TRUE", true),
        ("0", false),
        ("false", false),
        ("False", false),
        ("FaLsE", false),
        ("FALSE", false),
    ];

    /// Header values that must be rejected in both required and optional mode.
    const INVALID_CASES: [&str; 3] = ["", "foo", "yes"];

    #[tokio::test]
    async fn test_header_option_value_required_happy_path() {
        for (str_value, expected_output) in VALID_CASES {
            let request = Request::builder()
                .method(Method::GET)
                .uri("https://www.example.com")
                .header("x-unit-value", str_value)
                .body(())
                .unwrap();

            let inner_service = rama_core::service::service_fn(
                move |ctx: Context<()>, _req: Request<()>| async move {
                    assert_eq!(expected_output, ctx.contains::<UnitValue>());
                    Ok::<_, std::convert::Infallible>(())
                },
            );

            let service = HeaderOptionValueService::<UnitValue, _>::required(
                inner_service,
                HeaderName::from_static("x-unit-value"),
            );

            service.serve(Context::default(), request).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_found() {
        for (str_value, expected_output) in VALID_CASES {
            let request = Request::builder()
                .method(Method::GET)
                .uri("https://www.example.com")
                .header("x-unit-value", str_value)
                .body(())
                .unwrap();

            let inner_service = rama_core::service::service_fn(
                move |ctx: Context<()>, _req: Request<()>| async move {
                    assert_eq!(expected_output, ctx.contains::<UnitValue>());
                    Ok::<_, std::convert::Infallible>(())
                },
            );

            let service = HeaderOptionValueService::<UnitValue, _>::optional(
                inner_service,
                HeaderName::from_static("x-unit-value"),
            );

            service.serve(Context::default(), request).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_header_option_value_trims_whitespace() {
        let test_cases = [(" true ", true), (" 1", true), ("0 ", false), (" FALSE ", false)];
        for (str_value, expected_output) in test_cases {
            let request = Request::builder()
                .method(Method::GET)
                .uri("https://www.example.com")
                .header("x-unit-value", str_value)
                .body(())
                .unwrap();

            let inner_service = rama_core::service::service_fn(
                move |ctx: Context<()>, _req: Request<()>| async move {
                    assert_eq!(expected_output, ctx.contains::<UnitValue>());
                    Ok::<_, std::convert::Infallible>(())
                },
            );

            let service = HeaderOptionValueService::<UnitValue, _>::required(
                inner_service,
                HeaderName::from_static("x-unit-value"),
            );

            service.serve(Context::default(), request).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_missing() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(!ctx.contains::<UnitValue>());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderOptionValueService::<UnitValue, _>::optional(
            inner_service,
            HeaderName::from_static("x-unit-value"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_option_value_required_missing_header() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderOptionValueService::<UnitValue, _>::required(
            inner_service,
            HeaderName::from_static("x-unit-value"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_option_value_required_invalid_value() {
        for test_case in INVALID_CASES {
            let request = Request::builder()
                .method(Method::GET)
                .uri("https://www.example.com")
                .header("x-unit-value", test_case)
                .body(())
                .unwrap();

            let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
                Ok::<_, std::convert::Infallible>(())
            });

            let service = HeaderOptionValueService::<UnitValue, _>::required(
                inner_service,
                HeaderName::from_static("x-unit-value"),
            );

            let result = service.serve(Context::default(), request).await;
            assert!(result.is_err());
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_invalid_value() {
        for test_case in INVALID_CASES {
            let request = Request::builder()
                .method(Method::GET)
                .uri("https://www.example.com")
                .header("x-unit-value", test_case)
                .body(())
                .unwrap();

            let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
                Ok::<_, std::convert::Infallible>(())
            });

            let service = HeaderOptionValueService::<UnitValue, _>::optional(
                inner_service,
                HeaderName::from_static("x-unit-value"),
            );

            let result = service.serve(Context::default(), request).await;
            assert!(result.is_err());
        }
    }

    #[tokio::test]
    async fn test_header_option_value_layer_required_found() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-unit-value", "true")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.contains::<UnitValue>());

                Ok::<_, std::convert::Infallible>(())
            });

        let layer =
            HeaderOptionValueLayer::<UnitValue>::required(HeaderName::from_static("x-unit-value"));
        let service = layer.layer(inner_service);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_option_value_layer_optional_missing() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(!ctx.contains::<UnitValue>());

                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderOptionValueLayer::<UnitValue>::optional(HeaderName::from_static("x-unit-value"))
                .into_layer(inner_service);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_option_value_layer_required_missing_header() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service =
            HeaderOptionValueLayer::<UnitValue>::required(HeaderName::from_static("x-unit-value"))
                .into_layer(inner_service);

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }
}