1use crate::HeaderName;
41use crate::{
42 Request,
43 utils::{HeaderValueErr, HeaderValueGetter},
44};
45use rama_core::{Context, Layer, Service, error::BoxError};
46use rama_utils::macros::define_inner_service_accessors;
47use std::iter::FromIterator;
48use std::str::FromStr;
49use std::{fmt, marker::PhantomData};
50
51pub struct HeaderFromStrConfigService<T, S, C = Vec<T>> {
56 inner: S,
57 header_name: HeaderName,
58 optional: bool,
59 repeat: bool,
60 _marker: PhantomData<fn() -> (T, C)>,
61}
62
63impl<T, S, C> HeaderFromStrConfigService<T, S, C> {
64 define_inner_service_accessors!();
65
66 pub const fn required(inner: S, header_name: HeaderName) -> Self {
70 Self {
71 inner,
72 header_name,
73 optional: false,
74 repeat: false,
75 _marker: PhantomData,
76 }
77 }
78
79 pub const fn optional(inner: S, header_name: HeaderName) -> Self {
83 Self {
84 inner,
85 header_name,
86 optional: true,
87 repeat: false,
88 _marker: PhantomData,
89 }
90 }
91
92 pub fn set_repeat(&mut self, repeat: bool) -> &mut Self {
95 self.repeat = repeat;
96 self
97 }
98
99 pub fn with_repeat(mut self, repeat: bool) -> Self {
102 self.repeat = repeat;
103 self
104 }
105}
106
107impl<T, S, C> fmt::Debug for HeaderFromStrConfigService<T, S, C>
108where
109 S: fmt::Debug,
110{
111 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112 f.debug_struct("HeaderFromStrConfigService")
113 .field("inner", &self.inner)
114 .field("header_name", &self.header_name)
115 .field("optional", &self.optional)
116 .field("repeat", &self.repeat)
117 .field(
118 "_marker",
119 &format_args!("{}", std::any::type_name::<fn() -> (T, C)>()),
120 )
121 .finish()
122 }
123}
124
125impl<T, S, C> Clone for HeaderFromStrConfigService<T, S, C>
126where
127 S: Clone,
128{
129 fn clone(&self) -> Self {
130 Self {
131 inner: self.inner.clone(),
132 header_name: self.header_name.clone(),
133 optional: self.optional,
134 repeat: self.repeat,
135 _marker: PhantomData,
136 }
137 }
138}
139
140impl<T, S, State, Body, E, C> Service<State, Request<Body>> for HeaderFromStrConfigService<T, S, C>
141where
142 S: Service<State, Request<Body>, Error = E>,
143 T: FromStr<Err: Into<BoxError> + Send + Sync + 'static> + Send + Sync + 'static + Clone,
144 C: FromIterator<T> + Send + Sync + 'static + Clone,
145 State: Clone + Send + Sync + 'static,
146 Body: Send + Sync + 'static,
147 E: Into<BoxError> + Send + Sync + 'static,
148{
149 type Response = S::Response;
150 type Error = BoxError;
151
152 async fn serve(
153 &self,
154 mut ctx: Context<State>,
155 request: Request<Body>,
156 ) -> Result<Self::Response, Self::Error> {
157 if self.repeat {
158 let headers = request.headers().get_all(&self.header_name);
159 let mut parsed_values = headers
160 .into_iter()
161 .flat_map(|value| {
162 value.to_str().into_iter().flat_map(|string| {
163 string
164 .split(',')
165 .filter_map(|x| match x.trim() {
166 "" => None,
167 y => Some(y),
168 })
169 .map(|x| x.parse::<T>().map_err(Into::into))
170 })
171 })
172 .peekable();
173
174 if parsed_values.peek().is_none() {
175 if !self.optional {
176 return Err(HeaderValueErr::HeaderMissing(self.header_name.to_string()).into());
177 }
178 } else {
179 let values = parsed_values.collect::<Result<C, _>>()?;
180 ctx.insert(values);
181 }
182 } else {
183 match request.header_str(&self.header_name) {
184 Ok(s) => {
185 let cfg: T = s.parse().map_err(Into::into)?;
186 ctx.insert(cfg);
187 }
188 Err(HeaderValueErr::HeaderMissing(_)) if self.optional => (),
189 Err(err) => {
190 return Err(err.into());
191 }
192 }
193 }
194
195 self.inner.serve(ctx, request).await.map_err(Into::into)
196 }
197}
198
199pub struct HeaderFromStrConfigLayer<T, C = Vec<T>> {
204 header_name: HeaderName,
205 optional: bool,
206 repeat: bool,
207 _marker: PhantomData<fn() -> (T, C)>,
208}
209
210impl<T, C: fmt::Debug> fmt::Debug for HeaderFromStrConfigLayer<T, C> {
211 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
212 f.debug_struct("HeaderFromStrConfigLayer")
213 .field("header_name", &self.header_name)
214 .field("optional", &self.optional)
215 .field("repeat", &self.repeat)
216 .field(
217 "_marker",
218 &format_args!("{}", std::any::type_name::<fn() -> (T, C)>()),
219 )
220 .finish()
221 }
222}
223
224impl<T, C> Clone for HeaderFromStrConfigLayer<T, C> {
225 fn clone(&self) -> Self {
226 Self {
227 header_name: self.header_name.clone(),
228 optional: self.optional,
229 repeat: self.repeat,
230 _marker: PhantomData,
231 }
232 }
233}
234
235impl<T, C> HeaderFromStrConfigLayer<T, C> {
236 pub fn required(header_name: HeaderName) -> Self {
240 Self {
241 header_name,
242 optional: false,
243 repeat: false,
244 _marker: PhantomData,
245 }
246 }
247
248 pub fn optional(header_name: HeaderName) -> Self {
252 Self {
253 header_name,
254 optional: true,
255 repeat: false,
256 _marker: PhantomData,
257 }
258 }
259
260 pub fn set_repeat(&mut self, repeat: bool) -> &mut Self {
263 self.repeat = repeat;
264 self
265 }
266
267 pub fn with_repeat(mut self, repeat: bool) -> Self {
270 self.repeat = repeat;
271 self
272 }
273}
274
275impl<T, S, C> Layer<S> for HeaderFromStrConfigLayer<T, C> {
276 type Service = HeaderFromStrConfigService<T, S, C>;
277
278 fn layer(&self, inner: S) -> Self::Service {
279 HeaderFromStrConfigService {
280 inner,
281 header_name: self.header_name.clone(),
282 optional: self.optional,
283 repeat: self.repeat,
284 _marker: PhantomData,
285 }
286 }
287
288 fn into_layer(self, inner: S) -> Self::Service {
289 HeaderFromStrConfigService {
290 inner,
291 header_name: self.header_name,
292 optional: self.optional,
293 repeat: self.repeat,
294 _marker: PhantomData,
295 }
296 }
297}
298
#[cfg(test)]
mod test {
    use super::*;
    use crate::Method;
    use std::collections::{HashSet, LinkedList};

    /// Builds a `GET https://www.example.com` request carrying `headers`, added in
    /// order (a repeated name therefore yields multiple header entries).
    fn request_with_headers(headers: &[(&str, &str)]) -> Request<()> {
        let builder = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com");
        headers
            .iter()
            .fold(builder, |builder, &(name, value)| builder.header(name, value))
            .body(())
            .unwrap()
    }

    #[tokio::test]
    async fn test_header_config_required_happy_path() {
        let request = request_with_headers(&[("x-proxy-id", "42")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let id: &usize = ctx.get().unwrap();
                assert_eq!(*id, 42);

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_repeat_happy_path() {
        let request = request_with_headers(&[("x-proxy-labels", "foo,bar ,baz, fin ")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &Vec<String> = ctx.get().unwrap();
                assert_eq!("foo+bar+baz+fin", labels.join("+"));

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_repeat_custom_container() {
        let request = request_with_headers(&[("x-proxy-labels", "foo,bar,baz,foo")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &HashSet<String> = ctx.get().unwrap();
                assert_eq!(3, labels.len());
                assert!(labels.contains("foo"));
                assert!(labels.contains("bar"));
                assert!(labels.contains("baz"));

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _, HashSet<String>>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_repeat_linked_list() {
        let request = request_with_headers(&[("x-proxy-labels", "foo,bar,baz")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &LinkedList<String> = ctx.get().unwrap();
                let mut iter = labels.iter();
                assert_eq!(Some("foo"), iter.next().map(|x| x.as_str()));
                assert_eq!(Some("bar"), iter.next().map(|x| x.as_str()));
                assert_eq!(Some("baz"), iter.next().map(|x| x.as_str()));
                assert_eq!(None, iter.next());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _, LinkedList<String>>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_repeat_happy_path_multi_header() {
        let request = request_with_headers(&[
            ("x-proxy-labels", "foo,bar "),
            ("x-Proxy-Labels", "baz "),
            ("X-PROXY-LABELS", " fin"),
        ]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &Vec<String> = ctx.get().unwrap();
                assert_eq!("foo+bar+baz+fin", labels.join("+"));

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_found() {
        let request = request_with_headers(&[("x-proxy-id", "42")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let id: usize = *ctx.get().unwrap();
                assert_eq!(id, 42);

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_repeat_optional_found() {
        let request = request_with_headers(&[("x-proxy-labels", "foo,bar ,baz, fin ")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &Vec<String> = ctx.get().unwrap();
                assert_eq!("foo+bar+baz+fin", labels.join("+"));

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_missing() {
        let request = request_with_headers(&[]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<usize>().is_none());
                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_repeat_optional_missing() {
        let request = request_with_headers(&[]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Vec<String>>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_missing_header() {
        let request = request_with_headers(&[]);

        let inner_service =
            rama_core::service::service_fn(async |_ctx: Context<()>, _req: Request<()>| {
                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_repeat_required_missing() {
        let request = request_with_headers(&[]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Vec<String>>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<String, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_required_invalid_config() {
        let request = request_with_headers(&[("x-proxy-id", "foo")]);

        let inner_service =
            rama_core::service::service_fn(async |_ctx: Context<()>, _req: Request<()>| {
                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_repeat_required_invalid_config() {
        let request = request_with_headers(&[("x-proxy-labels", "42,foo")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Vec<String>>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_optional_invalid_config() {
        let request = request_with_headers(&[("x-proxy-id", "foo")]);

        let inner_service =
            rama_core::service::service_fn(async |_ctx: Context<()>, _req: Request<()>| {
                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-id"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_repeat_optional_invalid_config() {
        let request = request_with_headers(&[("x-proxy-labels", "42,foo")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Vec<String>>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigService::<usize, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-labels"),
        )
        .with_repeat(true);

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    // The layer path was previously untested: it must produce a service with the
    // same header name and flags as the layer.

    #[tokio::test]
    async fn test_header_config_layer_required_repeat_happy_path() {
        let request = request_with_headers(&[("x-proxy-labels", "foo, bar")]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let labels: &Vec<String> = ctx.get().unwrap();
                assert_eq!("foo+bar", labels.join("+"));

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderFromStrConfigLayer::<String>::required(HeaderName::from_static(
            "x-proxy-labels",
        ))
        .with_repeat(true)
        .into_layer(inner_service);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_layer_optional_missing() {
        let request = request_with_headers(&[]);

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<usize>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let config_layer =
            HeaderFromStrConfigLayer::<usize>::optional(HeaderName::from_static("x-proxy-id"));
        let service = config_layer.layer(inner_service);

        service.serve(Context::default(), request).await.unwrap();
    }
}