rama_http/layer/forwarded/
get_forwarded.rs1use crate::Request;
2use crate::headers::{
3 ForwardHeader, HeaderMapExt, Via, XForwardedFor, XForwardedHost, XForwardedProto,
4};
5use rama_core::{Context, Layer, Service};
6use rama_net::forwarded::Forwarded;
7use rama_net::forwarded::ForwardedElement;
8use rama_utils::macros::all_the_tuples_no_last_special_case;
9use std::fmt;
10use std::marker::PhantomData;
11
/// Layer that extracts forwarding information from the request header(s) `T`
/// and stores it as a [`Forwarded`] value in the request [`Context`].
///
/// `T` is either a single [`ForwardHeader`] (e.g. [`Forwarded`], [`Via`],
/// [`XForwardedFor`]) or a tuple of them. It defaults to the `Forwarded` header.
/// See [`GetForwardedHeadersService`] for how the headers are applied.
pub struct GetForwardedHeadersLayer<T = Forwarded> {
    // `fn() -> T` marks the header type(s) without storing a `T`, so the layer's
    // auto traits (`Send`, `Sync`) do not depend on `T`.
    _headers: PhantomData<fn() -> T>,
}
86
87impl<T: fmt::Debug> fmt::Debug for GetForwardedHeadersLayer<T> {
88 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
89 f.debug_struct("GetForwardedHeadersLayer")
90 .field(
91 "_headers",
92 &format_args!("{}", std::any::type_name::<fn() -> T>()),
93 )
94 .finish()
95 }
96}
97
98impl<T: Clone> Clone for GetForwardedHeadersLayer<T> {
99 fn clone(&self) -> Self {
100 Self {
101 _headers: PhantomData,
102 }
103 }
104}
105
106impl Default for GetForwardedHeadersLayer {
107 fn default() -> Self {
108 Self::forwarded()
109 }
110}
111
112impl<T> GetForwardedHeadersLayer<T> {
113 pub const fn new() -> Self {
115 Self {
116 _headers: PhantomData,
117 }
118 }
119}
120
121impl GetForwardedHeadersLayer {
122 #[inline]
123 pub fn forwarded() -> Self {
125 Self::new()
126 }
127}
128
129impl GetForwardedHeadersLayer<Via> {
130 #[inline]
131 pub fn via() -> Self {
133 Self::new()
134 }
135}
136
137impl GetForwardedHeadersLayer<XForwardedFor> {
138 #[inline]
139 pub fn x_forwarded_for() -> Self {
141 Self::new()
142 }
143}
144
145impl GetForwardedHeadersLayer<XForwardedHost> {
146 #[inline]
147 pub fn x_forwarded_host() -> Self {
149 Self::new()
150 }
151}
152
153impl GetForwardedHeadersLayer<XForwardedProto> {
154 #[inline]
155 pub fn x_forwarded_proto() -> Self {
157 Self::new()
158 }
159}
160
161impl<H, S> Layer<S> for GetForwardedHeadersLayer<H> {
162 type Service = GetForwardedHeadersService<S, H>;
163
164 fn layer(&self, inner: S) -> Self::Service {
165 Self::Service {
166 inner,
167 _headers: PhantomData,
168 }
169 }
170}
171
/// Service that reads the forwarding header(s) `T` from each request, turns
/// them into [`ForwardedElement`]s, and appends those to the [`Forwarded`]
/// value in the request [`Context`] (inserting one if absent) before calling
/// the inner service.
///
/// Produced by [`GetForwardedHeadersLayer`], or constructed directly with
/// [`GetForwardedHeadersService::new`] and its named constructors.
/// `T` defaults to the `Forwarded` header.
pub struct GetForwardedHeadersService<S, T = Forwarded> {
    // The wrapped service that receives the request afterwards.
    inner: S,
    // Type marker for the header type(s); no `T` is stored.
    _headers: PhantomData<fn() -> T>,
}
179
180impl<S: fmt::Debug, T> fmt::Debug for GetForwardedHeadersService<S, T> {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 f.debug_struct("GetForwardedHeadersService")
183 .field("inner", &self.inner)
184 .field("_headers", &format_args!("{}", std::any::type_name::<T>()))
185 .finish()
186 }
187}
188
189impl<S: Clone, T> Clone for GetForwardedHeadersService<S, T> {
190 fn clone(&self) -> Self {
191 GetForwardedHeadersService {
192 inner: self.inner.clone(),
193 _headers: PhantomData,
194 }
195 }
196}
197
198impl<S, T> GetForwardedHeadersService<S, T> {
199 pub const fn new(inner: S) -> Self {
201 Self {
202 inner,
203 _headers: PhantomData,
204 }
205 }
206}
207
208impl<S> GetForwardedHeadersService<S> {
209 #[inline]
210 pub fn forwarded(inner: S) -> Self {
212 Self::new(inner)
213 }
214}
215
216impl<S> GetForwardedHeadersService<S, Via> {
217 #[inline]
218 pub fn via(inner: S) -> Self {
220 Self::new(inner)
221 }
222}
223
224impl<S> GetForwardedHeadersService<S, XForwardedFor> {
225 #[inline]
226 pub fn x_forwarded_for(inner: S) -> Self {
228 Self::new(inner)
229 }
230}
231
232impl<S> GetForwardedHeadersService<S, XForwardedHost> {
233 #[inline]
234 pub fn x_forwarded_host(inner: S) -> Self {
236 Self::new(inner)
237 }
238}
239
240impl<S> GetForwardedHeadersService<S, XForwardedProto> {
241 #[inline]
242 pub fn x_forwarded_proto(inner: S) -> Self {
244 Self::new(inner)
245 }
246}
247
248macro_rules! get_forwarded_service_for_tuple {
249 ( $($ty:ident),* $(,)? ) => {
250 #[allow(non_snake_case)]
251 impl<$($ty,)* S, State, Body> Service<State, Request<Body>> for GetForwardedHeadersService<S, ($($ty,)*)>
252 where
253 $( $ty: ForwardHeader + Send + Sync + 'static, )*
254 S: Service<State, Request<Body>>,
255 Body: Send + 'static,
256 State: Clone + Send + Sync + 'static,
257 {
258 type Response = S::Response;
259 type Error = S::Error;
260
261 fn serve(
262 &self,
263 mut ctx: Context<State>,
264 req: Request<Body>,
265 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
266 let mut forwarded_elements: Vec<ForwardedElement> = Vec::with_capacity(1);
267
268 $(
269 if let Some($ty) = req.headers().typed_get::<$ty>() {
270 let mut iter = $ty.into_iter();
271 for element in forwarded_elements.iter_mut() {
272 let other = iter.next();
273 match other {
274 Some(other) => {
275 element.merge(other);
276 }
277 None => break,
278 }
279 }
280 for other in iter {
281 forwarded_elements.push(other);
282 }
283 }
284 )*
285
286 if !forwarded_elements.is_empty() {
287 match ctx.get_mut::<Forwarded>() {
288 Some(ref mut f) => {
289 f.extend(forwarded_elements);
290 }
291 None => {
292 let mut it = forwarded_elements.into_iter();
293 let mut forwarded = Forwarded::new(it.next().unwrap());
294 forwarded.extend(it);
295 ctx.insert(forwarded);
296 }
297 }
298 }
299
300 self.inner.serve(ctx, req)
301 }
302 }
303 }
304}
305
306impl<H, S, State, Body> Service<State, Request<Body>> for GetForwardedHeadersService<S, H>
307where
308 H: ForwardHeader + Send + Sync + 'static,
309 S: Service<State, Request<Body>>,
310 Body: Send + 'static,
311 State: Clone + Send + Sync + 'static,
312{
313 type Response = S::Response;
314 type Error = S::Error;
315
316 fn serve(
317 &self,
318 mut ctx: Context<State>,
319 req: Request<Body>,
320 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
321 let mut forwarded_elements: Vec<ForwardedElement> = Vec::with_capacity(1);
322
323 if let Some(header) = req.headers().typed_get::<H>() {
324 forwarded_elements.extend(header);
325 }
326
327 if !forwarded_elements.is_empty() {
328 match ctx.get_mut::<Forwarded>() {
329 Some(ref mut f) => {
330 f.extend(forwarded_elements);
331 }
332 None => {
333 let mut it = forwarded_elements.into_iter();
334 let mut forwarded = Forwarded::new(it.next().unwrap());
335 forwarded.extend(it);
336 ctx.insert(forwarded);
337 }
338 }
339 }
340
341 self.inner.serve(ctx, req)
342 }
343}
344
345all_the_tuples_no_last_special_case!(get_forwarded_service_for_tuple);
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        IntoResponse, Response, StatusCode,
        headers::{ClientIp, TrueClientIp, XClientIp, XRealIp},
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_net::forwarded::{ForwardedProtocol, ForwardedVersion};
    use std::{convert::Infallible, net::IpAddr};

    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_get_forwarded_service_is_service() {
        assert_is_service(GetForwardedHeadersService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeadersService::via(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeadersService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeadersService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeadersService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeadersService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(GetForwardedHeadersService::<_, (TrueClientIp,)>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(
            GetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(service_fn(
                dummy_service_fn,
            )),
        );
        assert_is_service(
            GetForwardedHeadersLayer::forwarded().into_layer(service_fn(dummy_service_fn)),
        );
        assert_is_service(GetForwardedHeadersLayer::via().into_layer(service_fn(dummy_service_fn)));
        assert_is_service(
            GetForwardedHeadersLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
        );
        assert_is_service(
            GetForwardedHeadersLayer::<(ClientIp, TrueClientIp)>::new()
                .into_layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_get_forwarded_header_forwarded() {
        let service = GetForwardedHeadersLayer::forwarded().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert_eq!(forwarded.client_proto(), Some(ForwardedProtocol::HTTP));
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_via() {
        let service =
            GetForwardedHeadersLayer::via().into_layer(service_fn(async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<Forwarded>().unwrap();
                assert!(forwarded.client_ip().is_none());
                assert_eq!(
                    forwarded.iter().next().unwrap().ref_forwarded_by(),
                    Some(&(IpAddr::from([12, 23, 34, 45]), 5000).into())
                );
                assert!(forwarded.client_proto().is_none());
                assert_eq!(forwarded.client_version(), Some(ForwardedVersion::HTTP_11));
                Ok::<_, Infallible>(())
            }));

        let req = Request::builder()
            .header("Via", "1.1 12.23.34.45:5000")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_x_forwarded_for() {
        let service = GetForwardedHeadersLayer::x_forwarded_for().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert!(forwarded.client_proto().is_none());
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder()
            .header("X-Forwarded-For", "12.23.34.45, 127.0.0.1")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    /// A request without the configured header must not create a `Forwarded`.
    #[tokio::test]
    async fn test_get_forwarded_header_missing_leaves_context_untouched() {
        let service = GetForwardedHeadersLayer::forwarded().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                assert!(ctx.get::<Forwarded>().is_none());
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder().body(()).unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    /// Exercises the tuple implementation: elements from later headers are
    /// merged position-wise into the elements of earlier ones.
    #[tokio::test]
    async fn test_get_forwarded_header_tuple_merges_elements() {
        let service = GetForwardedHeadersLayer::<(XForwardedFor, XForwardedProto)>::new()
            .into_layer(service_fn(async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert_eq!(forwarded.client_proto(), Some(ForwardedProtocol::HTTPS));
                Ok::<_, Infallible>(())
            }));

        let req = Request::builder()
            .header("X-Forwarded-For", "12.23.34.45")
            .header("X-Forwarded-Proto", "https")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }
}