1use crate::{HeaderName, header::AsHeaderName};
47use crate::{
48 Request,
49 utils::{HeaderValueErr, HeaderValueGetter},
50};
51use rama_core::{Context, Layer, Service, error::BoxError};
52use rama_utils::macros::define_inner_service_accessors;
53use serde::de::DeserializeOwned;
54use std::{fmt, marker::PhantomData};
55
56pub fn extract_header_config<H, T, G>(request: &G, header_name: H) -> Result<T, HeaderValueErr>
58where
59 H: AsHeaderName + Copy,
60 T: DeserializeOwned + Clone + Send + Sync + 'static,
61 G: HeaderValueGetter,
62{
63 let value = request.header_str(header_name)?;
64 let config = serde_html_form::from_str::<T>(value)
65 .map_err(|_| HeaderValueErr::HeaderInvalid(header_name.as_str().to_owned()))?;
66 Ok(config)
67}
68
69pub struct HeaderConfigService<T, S> {
74 inner: S,
75 header_name: HeaderName,
76 optional: bool,
77 _marker: PhantomData<fn() -> T>,
78}
79
80impl<T, S> HeaderConfigService<T, S> {
81 pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
86 Self {
87 inner,
88 header_name,
89 optional,
90 _marker: PhantomData,
91 }
92 }
93
94 define_inner_service_accessors!();
95
96 pub const fn required(inner: S, header_name: HeaderName) -> Self {
100 Self::new(inner, header_name, false)
101 }
102
103 pub const fn optional(inner: S, header_name: HeaderName) -> Self {
107 Self::new(inner, header_name, true)
108 }
109}
110
111impl<T, S: fmt::Debug> fmt::Debug for HeaderConfigService<T, S> {
112 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113 f.debug_struct("HeaderConfigService")
114 .field("inner", &self.inner)
115 .field("header_name", &self.header_name)
116 .field("optional", &self.optional)
117 .field(
118 "_marker",
119 &format_args!("{}", std::any::type_name::<fn() -> T>()),
120 )
121 .finish()
122 }
123}
124
125impl<T, S> Clone for HeaderConfigService<T, S>
126where
127 S: Clone,
128{
129 fn clone(&self) -> Self {
130 Self {
131 inner: self.inner.clone(),
132 header_name: self.header_name.clone(),
133 optional: self.optional,
134 _marker: PhantomData,
135 }
136 }
137}
138
139impl<T, S, State, Body, E> Service<State, Request<Body>> for HeaderConfigService<T, S>
140where
141 S: Service<State, Request<Body>, Error = E>,
142 T: DeserializeOwned + Clone + Send + Sync + 'static,
143 State: Clone + Send + Sync + 'static,
144 Body: Send + Sync + 'static,
145 E: Into<BoxError> + Send + Sync + 'static,
146{
147 type Response = S::Response;
148 type Error = BoxError;
149
150 async fn serve(
151 &self,
152 mut ctx: Context<State>,
153 request: Request<Body>,
154 ) -> Result<Self::Response, Self::Error> {
155 let config = match extract_header_config::<_, T, _>(&request, &self.header_name) {
156 Ok(config) => config,
157 Err(err) => {
158 if self.optional && matches!(err, crate::utils::HeaderValueErr::HeaderMissing(_)) {
159 tracing::debug!(error = %err, "failed to extract header config");
160 return self.inner.serve(ctx, request).await.map_err(Into::into);
161 } else {
162 return Err(err.into());
163 }
164 }
165 };
166 ctx.insert(config);
167 self.inner.serve(ctx, request).await.map_err(Into::into)
168 }
169}
170
171pub struct HeaderConfigLayer<T> {
176 header_name: HeaderName,
177 optional: bool,
178 _marker: PhantomData<fn() -> T>,
179}
180
181impl<T> fmt::Debug for HeaderConfigLayer<T> {
182 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 f.debug_struct("HeaderConfigLayer")
184 .field("header_name", &self.header_name)
185 .field("optional", &self.optional)
186 .field(
187 "_marker",
188 &format_args!("{}", std::any::type_name::<fn() -> T>()),
189 )
190 .finish()
191 }
192}
193
194impl<T> Clone for HeaderConfigLayer<T> {
195 fn clone(&self) -> Self {
196 Self {
197 header_name: self.header_name.clone(),
198 optional: self.optional,
199 _marker: PhantomData,
200 }
201 }
202}
203
204impl<T> HeaderConfigLayer<T> {
205 pub fn required(header_name: HeaderName) -> Self {
209 Self {
210 header_name,
211 optional: false,
212 _marker: PhantomData,
213 }
214 }
215
216 pub fn optional(header_name: HeaderName) -> Self {
220 Self {
221 header_name,
222 optional: true,
223 _marker: PhantomData,
224 }
225 }
226}
227
228impl<T, S> Layer<S> for HeaderConfigLayer<T> {
229 type Service = HeaderConfigService<T, S>;
230
231 fn layer(&self, inner: S) -> Self::Service {
232 HeaderConfigService::new(inner, self.header_name.clone(), self.optional)
233 }
234
235 fn into_layer(self, inner: S) -> Self::Service {
236 HeaderConfigService::new(inner, self.header_name, self.optional)
237 }
238}
239
#[cfg(test)]
mod test {
    use serde::Deserialize;

    use crate::Method;

    use super::*;

    #[tokio::test]
    async fn test_header_config_required_happy_path() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=E%26G&n=1&b=true")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let cfg: &Config = ctx.get().unwrap();
                assert_eq!(cfg.s, "E&G");
                assert_eq!(cfg.n, 1);
                assert!(cfg.m.is_none());
                assert!(cfg.b);

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_found() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=E%26G&n=1&b=true")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let cfg: &Config = ctx.get().unwrap();
                assert_eq!(cfg.s, "E&G");
                assert_eq!(cfg.n, 1);
                assert!(cfg.m.is_none());
                assert!(cfg.b);

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_missing() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Config>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_missing_header() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_required_invalid_config() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=bar&n=1&b=invalid")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_header_config_optional_invalid_config() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=bar&n=1&b=invalid")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static("x-proxy-config"),
        );

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    // The tests below exercise `HeaderConfigLayer`, covering both the borrowing
    // `layer` and the consuming `into_layer` entry points.

    #[tokio::test]
    async fn test_header_config_layer_required_happy_path() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .header("x-proxy-config", "s=E%26G&n=1&b=true")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                let cfg: &Config = ctx.get().unwrap();
                assert_eq!(cfg.s, "E&G");
                assert_eq!(cfg.n, 1);
                assert!(cfg.m.is_none());
                assert!(cfg.b);

                Ok::<_, std::convert::Infallible>(())
            });

        let layer =
            HeaderConfigLayer::<Config>::required(HeaderName::from_static("x-proxy-config"));
        let service = layer.layer(inner_service);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_layer_optional_missing() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service =
            rama_core::service::service_fn(async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<Config>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderConfigLayer::<Config>::optional(HeaderName::from_static("x-proxy-config"))
                .into_layer(inner_service);

        service.serve(Context::default(), request).await.unwrap();
    }

    #[tokio::test]
    async fn test_header_config_layer_required_missing_header() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com")
            .body(())
            .unwrap();

        let inner_service = rama_core::service::service_fn(async |_: Request<()>| {
            Ok::<_, std::convert::Infallible>(())
        });

        let service =
            HeaderConfigLayer::<Config>::required(HeaderName::from_static("x-proxy-config"))
                .into_layer(inner_service);

        let result = service.serve(Context::default(), request).await;
        assert!(result.is_err());
    }

    /// Configuration decoded from the `x-proxy-config` header in these tests.
    #[derive(Debug, Deserialize, Clone)]
    struct Config {
        s: String,
        n: i32,
        m: Option<i32>,
        b: bool,
    }
}