1use crate::{utils::HeaderValueGetter, HeaderName, Request};
7use rama_core::{
8 error::{BoxError, ErrorExt, OpaqueError},
9 Context, Layer, Service,
10};
11use rama_utils::macros::define_inner_service_accessors;
12use std::{fmt, marker::PhantomData};
13
14pub struct HeaderOptionValueService<T, S> {
18 inner: S,
19 header_name: HeaderName,
20 optional: bool,
21 _marker: PhantomData<fn() -> T>,
22}
23
24impl<T, S> HeaderOptionValueService<T, S> {
25 pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
30 Self {
31 inner,
32 header_name,
33 optional,
34 _marker: PhantomData,
35 }
36 }
37
38 define_inner_service_accessors!();
39
40 pub const fn required(inner: S, header_name: HeaderName) -> Self {
44 Self::new(inner, header_name, false)
45 }
46
47 pub const fn optional(inner: S, header_name: HeaderName) -> Self {
51 Self::new(inner, header_name, true)
52 }
53}
54
55impl<T, S: fmt::Debug> fmt::Debug for HeaderOptionValueService<T, S> {
56 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57 f.debug_struct("HeaderOptionValueService")
58 .field("inner", &self.inner)
59 .field("header_name", &self.header_name)
60 .field("optional", &self.optional)
61 .field(
62 "_marker",
63 &format_args!("{}", std::any::type_name::<fn() -> T>()),
64 )
65 .finish()
66 }
67}
68
69impl<T, S> Clone for HeaderOptionValueService<T, S>
70where
71 S: Clone,
72{
73 fn clone(&self) -> Self {
74 Self {
75 inner: self.inner.clone(),
76 header_name: self.header_name.clone(),
77 optional: self.optional,
78 _marker: PhantomData,
79 }
80 }
81}
82
83impl<T, S, State, Body, E> Service<State, Request<Body>> for HeaderOptionValueService<T, S>
84where
85 S: Service<State, Request<Body>, Error = E>,
86 T: Default + Clone + Send + Sync + 'static,
87 State: Clone + Send + Sync + 'static,
88 Body: Send + Sync + 'static,
89 E: Into<BoxError> + Send + Sync + 'static,
90{
91 type Response = S::Response;
92 type Error = BoxError;
93
94 async fn serve(
95 &self,
96 mut ctx: Context<State>,
97 request: Request<Body>,
98 ) -> Result<Self::Response, Self::Error> {
99 match request.header_str(&self.header_name) {
100 Ok(str_value) => {
101 let str_value = str_value.trim();
102 if str_value == "1" || str_value.eq_ignore_ascii_case("true") {
103 ctx.insert(T::default());
104 } else if str_value != "0" && !str_value.eq_ignore_ascii_case("false") {
105 return Err(OpaqueError::from_display(format!(
106 "invalid '{}' header option: '{}'",
107 self.header_name, str_value
108 ))
109 .into_boxed());
110 }
111 }
112 Err(err) => {
113 if self.optional && matches!(err, crate::utils::HeaderValueErr::HeaderMissing(_)) {
114 tracing::debug!(
115 error = %err,
116 header_name = %self.header_name,
117 "failed to determine header option",
118 );
119 return self.inner.serve(ctx, request).await.map_err(Into::into);
120 } else {
121 return Err(err
122 .with_context(|| format!("determine '{}' header option", self.header_name))
123 .into_boxed());
124 }
125 }
126 };
127 self.inner.serve(ctx, request).await.map_err(Into::into)
128 }
129}
130
131pub struct HeaderOptionValueLayer<T> {
135 header_name: HeaderName,
136 optional: bool,
137 _marker: PhantomData<fn() -> T>,
138}
139
140impl<T> fmt::Debug for HeaderOptionValueLayer<T> {
141 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142 f.debug_struct("HeaderOptionValueLayer")
143 .field("header_name", &self.header_name)
144 .field("optional", &self.optional)
145 .field(
146 "_marker",
147 &format_args!("{}", std::any::type_name::<fn() -> T>()),
148 )
149 .finish()
150 }
151}
152
153impl<T> Clone for HeaderOptionValueLayer<T> {
154 fn clone(&self) -> Self {
155 Self {
156 header_name: self.header_name.clone(),
157 optional: self.optional,
158 _marker: PhantomData,
159 }
160 }
161}
162
163impl<T> HeaderOptionValueLayer<T> {
164 pub fn required(header_name: HeaderName) -> Self {
168 Self {
169 header_name,
170 optional: false,
171 _marker: PhantomData,
172 }
173 }
174
175 pub fn optional(header_name: HeaderName) -> Self {
179 Self {
180 header_name,
181 optional: true,
182 _marker: PhantomData,
183 }
184 }
185}
186
187impl<T, S> Layer<S> for HeaderOptionValueLayer<T> {
188 type Service = HeaderOptionValueService<T, S>;
189
190 fn layer(&self, inner: S) -> Self::Service {
191 HeaderOptionValueService::new(inner, self.header_name.clone(), self.optional)
192 }
193}
194
#[cfg(test)]
mod test {
    use super::*;
    use crate::Method;

    /// Header inspected by every service under test.
    const HEADER: &str = "x-unit-value";

    /// Accepted header values, paired with whether `UnitValue` is expected to
    /// be present in the context seen by the inner service.
    const VALID_CASES: [(&str, bool); 10] = [
        ("1", true),
        ("true", true),
        ("True", true),
        ("TrUE", true),
        ("TRUE", true),
        ("0", false),
        ("false", false),
        ("False", false),
        ("FaLsE", false),
        ("FALSE", false),
    ];

    /// Header values that must be rejected, whether the header is optional or not.
    const INVALID_CASES: [&str; 3] = ["", "foo", "yes"];

    #[derive(Debug, Clone, Default)]
    struct UnitValue;

    /// Builds a `GET` request that carries `HEADER` only when `value` is `Some`.
    fn new_request(value: Option<&str>) -> Request<()> {
        let mut builder = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com");
        if let Some(value) = value {
            builder = builder.header(HEADER, value);
        }
        builder.body(()).unwrap()
    }

    /// Serves a request carrying `value` through a required or optional service
    /// and asserts whether `UnitValue` was inserted into the inner context.
    async fn assert_option(optional: bool, value: &'static str, expected: bool) {
        let inner_service = rama_core::service::service_fn(
            move |ctx: Context<()>, _req: Request<()>| async move {
                assert_eq!(
                    expected,
                    ctx.contains::<UnitValue>(),
                    "header value: {value:?}"
                );
                Ok::<_, std::convert::Infallible>(())
            },
        );

        let service = HeaderOptionValueService::<UnitValue, _>::new(
            inner_service,
            HeaderName::from_static(HEADER),
            optional,
        );

        service
            .serve(Context::default(), new_request(Some(value)))
            .await
            .unwrap();
    }

    /// Serves a request carrying `value` and asserts that the service rejects it.
    async fn assert_rejected(optional: bool, value: &str) {
        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderOptionValueService::<UnitValue, _>::new(
            inner_service,
            HeaderName::from_static(HEADER),
            optional,
        );

        let result = service
            .serve(Context::default(), new_request(Some(value)))
            .await;
        assert!(result.is_err(), "header value {value:?} should be rejected");
    }

    #[tokio::test]
    async fn test_header_option_value_required_happy_path() {
        for (value, expected) in VALID_CASES {
            assert_option(false, value, expected).await;
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_found() {
        for (value, expected) in VALID_CASES {
            assert_option(true, value, expected).await;
        }
    }

    #[tokio::test]
    async fn test_header_option_value_trims_whitespace() {
        // The service trims the raw header value before interpreting it.
        let cases = [
            (" 1", true),
            ("true ", true),
            ("\tTRUE\t", true),
            (" 0 ", false),
            ("\tfalse", false),
        ];
        for optional in [false, true] {
            for (value, expected) in cases {
                assert_option(optional, value, expected).await;
            }
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_missing() {
        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                assert!(!ctx.contains::<UnitValue>());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderOptionValueService::<UnitValue, _>::optional(
            inner_service,
            HeaderName::from_static(HEADER),
        );

        service
            .serve(Context::default(), new_request(None))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_header_option_value_required_missing_header() {
        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderOptionValueService::<UnitValue, _>::required(
            inner_service,
            HeaderName::from_static(HEADER),
        );

        let result = service.serve(Context::default(), new_request(None)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_option_value_required_invalid_value() {
        for value in INVALID_CASES {
            assert_rejected(false, value).await;
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_invalid_value() {
        for value in INVALID_CASES {
            assert_rejected(true, value).await;
        }
    }

    #[tokio::test]
    async fn test_header_option_value_layer() {
        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                assert!(ctx.contains::<UnitValue>());

                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderOptionValueLayer::<UnitValue>::required(HeaderName::from_static(HEADER))
                .layer(inner_service);

        service
            .serve(Context::default(), new_request(Some("true")))
            .await
            .unwrap();
    }
}