1use crate::{header::AsHeaderName, HeaderName};
47use crate::{
48 utils::{HeaderValueErr, HeaderValueGetter},
49 Request,
50};
51use rama_core::{error::BoxError, Context, Layer, Service};
52use rama_utils::macros::define_inner_service_accessors;
53use serde::de::DeserializeOwned;
54use std::{fmt, marker::PhantomData};
55
56pub fn extract_header_config<H, T, G>(request: &G, header_name: H) -> Result<T, HeaderValueErr>
58where
59 H: AsHeaderName + Copy,
60 T: DeserializeOwned + Clone + Send + Sync + 'static,
61 G: HeaderValueGetter,
62{
63 let value = request.header_str(header_name)?;
64 let config = serde_html_form::from_str::<T>(value)
65 .map_err(|_| HeaderValueErr::HeaderInvalid(header_name.as_str().to_owned()))?;
66 Ok(config)
67}
68
/// A service that deserializes a config of type `T` from a request header.
///
/// On success, it inserts the config into the request context and then calls the
/// inner service.
///
/// Whether a missing header is an error is controlled by `optional`; see
/// [`HeaderConfigService::required`] and [`HeaderConfigService::optional`].
pub struct HeaderConfigService<T, S> {
    /// The wrapped service, called once the header has been processed.
    inner: S,
    /// Name of the header that holds the URL-encoded config.
    header_name: HeaderName,
    /// `true`: a missing header is tolerated and the inner service runs without a config.
    /// `false`: a missing header makes the request fail.
    optional: bool,
    /// Ties the service to `T` without storing one.
    /// `fn() -> T` keeps this marker `Send + Sync` regardless of `T`.
    _marker: PhantomData<fn() -> T>,
}
79
80impl<T, S> HeaderConfigService<T, S> {
81 pub const fn new(inner: S, header_name: HeaderName, optional: bool) -> Self {
86 Self {
87 inner,
88 header_name,
89 optional,
90 _marker: PhantomData,
91 }
92 }
93
94 define_inner_service_accessors!();
95
96 pub const fn required(inner: S, header_name: HeaderName) -> Self {
100 Self::new(inner, header_name, false)
101 }
102
103 pub const fn optional(inner: S, header_name: HeaderName) -> Self {
107 Self::new(inner, header_name, true)
108 }
109}
110
111impl<T, S: fmt::Debug> fmt::Debug for HeaderConfigService<T, S> {
112 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113 f.debug_struct("HeaderConfigService")
114 .field("inner", &self.inner)
115 .field("header_name", &self.header_name)
116 .field("optional", &self.optional)
117 .field(
118 "_marker",
119 &format_args!("{}", std::any::type_name::<fn() -> T>()),
120 )
121 .finish()
122 }
123}
124
125impl<T, S> Clone for HeaderConfigService<T, S>
126where
127 S: Clone,
128{
129 fn clone(&self) -> Self {
130 Self {
131 inner: self.inner.clone(),
132 header_name: self.header_name.clone(),
133 optional: self.optional,
134 _marker: PhantomData,
135 }
136 }
137}
138
139impl<T, S, State, Body, E> Service<State, Request<Body>> for HeaderConfigService<T, S>
140where
141 S: Service<State, Request<Body>, Error = E>,
142 T: DeserializeOwned + Clone + Send + Sync + 'static,
143 State: Clone + Send + Sync + 'static,
144 Body: Send + Sync + 'static,
145 E: Into<BoxError> + Send + Sync + 'static,
146{
147 type Response = S::Response;
148 type Error = BoxError;
149
150 async fn serve(
151 &self,
152 mut ctx: Context<State>,
153 request: Request<Body>,
154 ) -> Result<Self::Response, Self::Error> {
155 let config = match extract_header_config::<_, T, _>(&request, &self.header_name) {
156 Ok(config) => config,
157 Err(err) => {
158 if self.optional && matches!(err, crate::utils::HeaderValueErr::HeaderMissing(_)) {
159 tracing::debug!(error = %err, "failed to extract header config");
160 return self.inner.serve(ctx, request).await.map_err(Into::into);
161 } else {
162 return Err(err.into());
163 }
164 }
165 };
166 ctx.insert(config);
167 self.inner.serve(ctx, request).await.map_err(Into::into)
168 }
169}
170
/// Layer that wraps services in a [`HeaderConfigService`].
///
/// Each produced service reads a config of type `T` from `header_name`.
pub struct HeaderConfigLayer<T> {
    /// Name of the header that holds the URL-encoded config.
    header_name: HeaderName,
    /// Passed to every produced service; `true` tolerates a missing header.
    optional: bool,
    /// Ties the layer to `T` without storing one.
    /// `fn() -> T` keeps this marker `Send + Sync` regardless of `T`.
    _marker: PhantomData<fn() -> T>,
}
180
181impl<T> fmt::Debug for HeaderConfigLayer<T> {
182 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 f.debug_struct("HeaderConfigLayer")
184 .field("header_name", &self.header_name)
185 .field("optional", &self.optional)
186 .field(
187 "_marker",
188 &format_args!("{}", std::any::type_name::<fn() -> T>()),
189 )
190 .finish()
191 }
192}
193
194impl<T> Clone for HeaderConfigLayer<T> {
195 fn clone(&self) -> Self {
196 Self {
197 header_name: self.header_name.clone(),
198 optional: self.optional,
199 _marker: PhantomData,
200 }
201 }
202}
203
204impl<T> HeaderConfigLayer<T> {
205 pub fn required(header_name: HeaderName) -> Self {
209 Self {
210 header_name,
211 optional: false,
212 _marker: PhantomData,
213 }
214 }
215
216 pub fn optional(header_name: HeaderName) -> Self {
220 Self {
221 header_name,
222 optional: true,
223 _marker: PhantomData,
224 }
225 }
226}
227
228impl<T, S> Layer<S> for HeaderConfigLayer<T> {
229 type Service = HeaderConfigService<T, S>;
230
231 fn layer(&self, inner: S) -> Self::Service {
232 HeaderConfigService::new(inner, self.header_name.clone(), self.optional)
233 }
234}
235
#[cfg(test)]
mod test {
    use serde::Deserialize;

    use crate::Method;

    use super::*;

    /// Header every test reads its config from.
    const CONFIG_HEADER: &str = "x-proxy-config";

    /// Well-formed config: `%26` decodes to `&`, and `m` is deliberately omitted.
    const VALID_CONFIG: &str = "s=E%26G&n=1&b=true";

    /// Malformed config: `b=invalid` cannot be deserialized into a `bool`.
    const INVALID_CONFIG: &str = "s=bar&n=1&b=invalid";

    /// Builds a `GET https://www.example.com` request.
    ///
    /// When `config` is given, it is added as the value of the config header.
    fn request_with_config(config: Option<&'static str>) -> Request<()> {
        let mut builder = Request::builder()
            .method(Method::GET)
            .uri("https://www.example.com");
        if let Some(value) = config {
            builder = builder.header(CONFIG_HEADER, value);
        }
        builder.body(()).unwrap()
    }

    /// Asserts that `cfg` holds exactly the values encoded in [`VALID_CONFIG`].
    fn assert_valid_config(cfg: &Config) {
        assert_eq!(cfg.s, "E&G");
        assert_eq!(cfg.n, 1);
        assert!(cfg.m.is_none());
        assert!(cfg.b);
    }

    #[tokio::test]
    async fn test_header_config_required_happy_path() {
        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                let cfg: &Config = ctx.get().unwrap();
                assert_valid_config(cfg);

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static(CONFIG_HEADER),
        );

        service
            .serve(Context::default(), request_with_config(Some(VALID_CONFIG)))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_found() {
        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                let cfg: &Config = ctx.get().unwrap();
                assert_valid_config(cfg);

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static(CONFIG_HEADER),
        );

        service
            .serve(Context::default(), request_with_config(Some(VALID_CONFIG)))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_header_config_optional_missing() {
        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                assert!(ctx.get::<Config>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static(CONFIG_HEADER),
        );

        service
            .serve(Context::default(), request_with_config(None))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_header_config_required_missing_header() {
        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static(CONFIG_HEADER),
        );

        let err = service
            .serve(Context::default(), request_with_config(None))
            .await
            .unwrap_err();
        // Pin the error kind, not just "some error": a required header must report it is missing.
        assert!(
            matches!(
                err.downcast_ref::<HeaderValueErr>(),
                Some(HeaderValueErr::HeaderMissing(_))
            ),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn test_header_config_required_invalid_config() {
        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::required(
            inner_service,
            HeaderName::from_static(CONFIG_HEADER),
        );

        let err = service
            .serve(Context::default(), request_with_config(Some(INVALID_CONFIG)))
            .await
            .unwrap_err();
        assert!(
            matches!(
                err.downcast_ref::<HeaderValueErr>(),
                Some(HeaderValueErr::HeaderInvalid(_))
            ),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn test_header_config_optional_invalid_config() {
        let inner_service = rama_core::service::service_fn(|_: Request<()>| async move {
            Ok::<_, std::convert::Infallible>(())
        });

        let service = HeaderConfigService::<Config, _>::optional(
            inner_service,
            HeaderName::from_static(CONFIG_HEADER),
        );

        // "Optional" only tolerates a missing header; a present-but-invalid one still fails.
        let err = service
            .serve(Context::default(), request_with_config(Some(INVALID_CONFIG)))
            .await
            .unwrap_err();
        assert!(
            matches!(
                err.downcast_ref::<HeaderValueErr>(),
                Some(HeaderValueErr::HeaderInvalid(_))
            ),
            "unexpected error: {err}"
        );
    }

    #[tokio::test]
    async fn test_header_config_layer_required_happy_path() {
        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                let cfg: &Config = ctx.get().unwrap();
                assert_valid_config(cfg);

                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderConfigLayer::<Config>::required(HeaderName::from_static(CONFIG_HEADER))
                .layer(inner_service);

        service
            .serve(Context::default(), request_with_config(Some(VALID_CONFIG)))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_header_config_layer_optional_missing() {
        let inner_service =
            rama_core::service::service_fn(|ctx: Context<()>, _req: Request<()>| async move {
                assert!(ctx.get::<Config>().is_none());

                Ok::<_, std::convert::Infallible>(())
            });

        let service =
            HeaderConfigLayer::<Config>::optional(HeaderName::from_static(CONFIG_HEADER))
                .layer(inner_service);

        service
            .serve(Context::default(), request_with_config(None))
            .await
            .unwrap();
    }

    #[derive(Debug, Deserialize, Clone)]
    struct Config {
        s: String,
        n: i32,
        m: Option<i32>,
        b: bool,
    }
}