1use crate::{HeaderName, Request, utils::HeaderValueGetter};
7use rama_core::{
8 Context, Layer, Service,
9 error::{BoxError, ErrorExt, OpaqueError},
10};
11use rama_utils::macros::define_inner_service_accessors;
12use std::{fmt, marker::PhantomData};
13
14pub struct HeaderOptionValueService<T, S> {
18 inner: S,
19 header_name: HeaderName,
20 optional: bool,
21 _marker: PhantomData<fn() -> T>,
22}
23
24impl<T, S> HeaderOptionValueService<T, S> {
25 pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
30 Self {
31 inner,
32 header_name,
33 optional,
34 _marker: PhantomData,
35 }
36 }
37
38 define_inner_service_accessors!();
39
40 pub const fn required(inner: S, header_name: HeaderName) -> Self {
44 Self::new(inner, header_name, false)
45 }
46
47 pub const fn optional(inner: S, header_name: HeaderName) -> Self {
51 Self::new(inner, header_name, true)
52 }
53}
54
55impl<T, S: fmt::Debug> fmt::Debug for HeaderOptionValueService<T, S> {
56 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57 f.debug_struct("HeaderOptionValueService")
58 .field("inner", &self.inner)
59 .field("header_name", &self.header_name)
60 .field("optional", &self.optional)
61 .field(
62 "_marker",
63 &format_args!("{}", std::any::type_name::<fn() -> T>()),
64 )
65 .finish()
66 }
67}
68
69impl<T, S> Clone for HeaderOptionValueService<T, S>
70where
71 S: Clone,
72{
73 fn clone(&self) -> Self {
74 Self {
75 inner: self.inner.clone(),
76 header_name: self.header_name.clone(),
77 optional: self.optional,
78 _marker: PhantomData,
79 }
80 }
81}
82
83impl<T, S, State, Body, E> Service<State, Request<Body>> for HeaderOptionValueService<T, S>
84where
85 S: Service<State, Request<Body>, Error = E>,
86 T: Default + Clone + Send + Sync + 'static,
87 State: Clone + Send + Sync + 'static,
88 Body: Send + Sync + 'static,
89 E: Into<BoxError> + Send + Sync + 'static,
90{
91 type Response = S::Response;
92 type Error = BoxError;
93
94 async fn serve(
95 &self,
96 mut ctx: Context<State>,
97 request: Request<Body>,
98 ) -> Result<Self::Response, Self::Error> {
99 match request.header_str(&self.header_name) {
100 Ok(str_value) => {
101 let str_value = str_value.trim();
102 if str_value == "1" || str_value.eq_ignore_ascii_case("true") {
103 ctx.insert(T::default());
104 } else if str_value != "0" && !str_value.eq_ignore_ascii_case("false") {
105 return Err(OpaqueError::from_display(format!(
106 "invalid '{}' header option: '{}'",
107 self.header_name, str_value
108 ))
109 .into_boxed());
110 }
111 }
112 Err(err) => {
113 if self.optional && matches!(err, crate::utils::HeaderValueErr::HeaderMissing(_)) {
114 tracing::debug!(
115 error = %err,
116 header_name = %self.header_name,
117 "failed to determine header option",
118 );
119 return self.inner.serve(ctx, request).await.map_err(Into::into);
120 } else {
121 return Err(err
122 .with_context(|| format!("determine '{}' header option", self.header_name))
123 .into_boxed());
124 }
125 }
126 };
127 self.inner.serve(ctx, request).await.map_err(Into::into)
128 }
129}
130
131pub struct HeaderOptionValueLayer<T> {
135 header_name: HeaderName,
136 optional: bool,
137 _marker: PhantomData<fn() -> T>,
138}
139
140impl<T> fmt::Debug for HeaderOptionValueLayer<T> {
141 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
142 f.debug_struct("HeaderOptionValueLayer")
143 .field("header_name", &self.header_name)
144 .field("optional", &self.optional)
145 .field(
146 "_marker",
147 &format_args!("{}", std::any::type_name::<fn() -> T>()),
148 )
149 .finish()
150 }
151}
152
153impl<T> Clone for HeaderOptionValueLayer<T> {
154 fn clone(&self) -> Self {
155 Self {
156 header_name: self.header_name.clone(),
157 optional: self.optional,
158 _marker: PhantomData,
159 }
160 }
161}
162
163impl<T> HeaderOptionValueLayer<T> {
164 pub fn required(header_name: HeaderName) -> Self {
168 Self {
169 header_name,
170 optional: false,
171 _marker: PhantomData,
172 }
173 }
174
175 pub fn optional(header_name: HeaderName) -> Self {
179 Self {
180 header_name,
181 optional: true,
182 _marker: PhantomData,
183 }
184 }
185}
186
187impl<T, S> Layer<S> for HeaderOptionValueLayer<T> {
188 type Service = HeaderOptionValueService<T, S>;
189
190 fn layer(&self, inner: S) -> Self::Service {
191 HeaderOptionValueService::new(inner, self.header_name.clone(), self.optional)
192 }
193
194 fn into_layer(self, inner: S) -> Self::Service {
195 HeaderOptionValueService::new(inner, self.header_name, self.optional)
196 }
197}
198
#[cfg(test)]
mod test {
    use super::*;
    use crate::Method;
    use std::sync::{Arc, Mutex};

    #[derive(Debug, Clone, Default)]
    struct UnitValue;

    /// Header inspected by every service under test.
    const HEADER: &str = "x-unit-value";

    /// Header values that must be accepted, paired with whether `UnitValue`
    /// is expected in the inner service's context.
    const ACCEPTED: &[(&str, bool)] = &[
        ("1", true),
        ("true", true),
        ("True", true),
        ("TrUE", true),
        ("TRUE", true),
        ("0", false),
        ("false", false),
        ("False", false),
        ("FaLsE", false),
        ("FALSE", false),
        // surrounding whitespace is trimmed before the value is interpreted
        (" 1 ", true),
        ("\ttrue", true),
        ("0 ", false),
        (" FALSE\t", false),
    ];

    /// Header values that must be rejected, regardless of optionality.
    const REJECTED: &[&str] = &["", "foo", "yes", "2", "on"];

    /// What happened when one request was served through the service under test.
    #[derive(Debug, PartialEq, Eq)]
    enum Outcome {
        /// The inner service ran; `unit_value` tells whether it found `UnitValue`.
        Forwarded { unit_value: bool },
        /// An error was returned and the inner service was never called.
        Rejected,
    }

    /// Builds a GET request carrying `header_value` in [`HEADER`], or without
    /// that header when `None`.
    fn make_request(header_value: Option<&str>) -> Request<()> {
        let builder = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com");
        let builder = match header_value {
            Some(value) => builder.header(HEADER, value),
            None => builder,
        };
        builder.body(()).unwrap()
    }

    /// Serves one request through a `HeaderOptionValueService<UnitValue, _>`
    /// and classifies the result.
    ///
    /// Panics if the service succeeded without calling the inner service, or
    /// failed after calling it; neither should ever happen.
    async fn run(header_value: Option<&str>, optional: bool) -> Outcome {
        let seen: Arc<Mutex<Option<bool>>> = Arc::default();
        let inner_service = rama_core::service::service_fn({
            let seen = seen.clone();
            move |ctx: Context<()>, _req: Request<()>| {
                let seen = seen.clone();
                async move {
                    *seen.lock().unwrap() = Some(ctx.contains::<UnitValue>());
                    Ok::<_, std::convert::Infallible>(())
                }
            }
        });

        let service = HeaderOptionValueService::<UnitValue, _>::new(
            inner_service,
            HeaderName::from_static(HEADER),
            optional,
        );
        let result = service
            .serve(Context::default(), make_request(header_value))
            .await;
        let seen = *seen.lock().unwrap();

        match (result, seen) {
            (Ok(()), Some(unit_value)) => Outcome::Forwarded { unit_value },
            (Err(_), None) => Outcome::Rejected,
            (result, seen) => {
                panic!("inconsistent outcome: result = {result:?}, inner saw = {seen:?}")
            }
        }
    }

    #[tokio::test]
    async fn test_header_option_value_required_happy_path() {
        for &(value, unit_value) in ACCEPTED {
            assert_eq!(
                run(Some(value), false).await,
                Outcome::Forwarded { unit_value },
                "header value: {value:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_found() {
        for &(value, unit_value) in ACCEPTED {
            assert_eq!(
                run(Some(value), true).await,
                Outcome::Forwarded { unit_value },
                "header value: {value:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_missing() {
        assert_eq!(
            run(None, true).await,
            Outcome::Forwarded { unit_value: false }
        );
    }

    #[tokio::test]
    async fn test_header_option_value_required_missing_header() {
        assert_eq!(run(None, false).await, Outcome::Rejected);
    }

    #[tokio::test]
    async fn test_header_option_value_required_invalid_value() {
        for &value in REJECTED {
            assert_eq!(
                run(Some(value), false).await,
                Outcome::Rejected,
                "header value: {value:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_header_option_value_optional_invalid_value() {
        for &value in REJECTED {
            assert_eq!(
                run(Some(value), true).await,
                Outcome::Rejected,
                "header value: {value:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_header_option_value_layer() {
        let layer =
            HeaderOptionValueLayer::<UnitValue>::optional(HeaderName::from_static(HEADER));

        // `layer` borrows, so the same layer can wrap another service below.
        let service = layer.layer(rama_core::service::service_fn(
            async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.contains::<UnitValue>());
                Ok::<_, std::convert::Infallible>(())
            },
        ));
        service
            .serve(Context::default(), make_request(Some("true")))
            .await
            .unwrap();

        // The optional mode must be carried over: a missing header is forwarded.
        let service = layer.into_layer(rama_core::service::service_fn(
            async |ctx: Context<()>, _req: Request<()>| {
                assert!(!ctx.contains::<UnitValue>());
                Ok::<_, std::convert::Infallible>(())
            },
        ));
        service
            .serve(Context::default(), make_request(None))
            .await
            .unwrap();
    }
}