rama_http/layer/forwarded/
set_forwarded.rs1use crate::Request;
2use crate::headers::HeaderMapExt;
3use crate::headers::forwarded::{
4 ForwardHeader, Via, XForwardedFor, XForwardedHost, XForwardedProto,
5};
6use rama_core::error::BoxError;
7use rama_core::{Context, Layer, Service};
8use rama_http_headers::forwarded::Forwarded;
9use rama_net::address::Domain;
10use rama_net::forwarded::{ForwardedElement, NodeId};
11use rama_net::http::RequestContext;
12use rama_net::stream::SocketInfo;
13use std::fmt;
14use std::marker::PhantomData;
15
16pub struct SetForwardedHeaderLayer<T = Forwarded> {
87 by_node: NodeId,
88 _headers: PhantomData<fn() -> T>,
89}
90
91impl<T: fmt::Debug> fmt::Debug for SetForwardedHeaderLayer<T> {
92 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
93 f.debug_struct("SetForwardedHeaderLayer")
94 .field("by_node", &self.by_node)
95 .field(
96 "_headers",
97 &format_args!("{}", std::any::type_name::<fn() -> T>()),
98 )
99 .finish()
100 }
101}
102
103impl<T: Clone> Clone for SetForwardedHeaderLayer<T> {
104 fn clone(&self) -> Self {
105 Self {
106 by_node: self.by_node.clone(),
107 _headers: PhantomData,
108 }
109 }
110}
111
112impl<T> SetForwardedHeaderLayer<T> {
113 pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
117 self.by_node = node_id.into();
118 self
119 }
120
121 pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
125 self.by_node = node_id.into();
126 self
127 }
128}
129
130impl<T> SetForwardedHeaderLayer<T> {
131 pub fn new() -> Self {
133 Self {
134 by_node: Domain::from_static("rama").into(),
135 _headers: PhantomData,
136 }
137 }
138}
139
140impl Default for SetForwardedHeaderLayer {
141 fn default() -> Self {
142 Self::forwarded()
143 }
144}
145
146impl SetForwardedHeaderLayer {
147 #[inline]
148 pub fn forwarded() -> Self {
150 Self::new()
151 }
152}
153
154impl SetForwardedHeaderLayer<Via> {
155 #[inline]
156 pub fn via() -> Self {
158 Self::new()
159 }
160}
161
162impl SetForwardedHeaderLayer<XForwardedFor> {
163 #[inline]
164 pub fn x_forwarded_for() -> Self {
166 Self::new()
167 }
168}
169
170impl SetForwardedHeaderLayer<XForwardedHost> {
171 #[inline]
172 pub fn x_forwarded_host() -> Self {
174 Self::new()
175 }
176}
177
178impl SetForwardedHeaderLayer<XForwardedProto> {
179 #[inline]
180 pub fn x_forwarded_proto() -> Self {
182 Self::new()
183 }
184}
185
186impl<H, S> Layer<S> for SetForwardedHeaderLayer<H> {
187 type Service = SetForwardedHeaderService<S, H>;
188
189 fn layer(&self, inner: S) -> Self::Service {
190 Self::Service {
191 inner,
192 by_node: self.by_node.clone(),
193 _headers: PhantomData,
194 }
195 }
196
197 fn into_layer(self, inner: S) -> Self::Service {
198 Self::Service {
199 inner,
200 by_node: self.by_node,
201 _headers: PhantomData,
202 }
203 }
204}
205
206pub struct SetForwardedHeaderService<S, T = Forwarded> {
211 inner: S,
212 by_node: NodeId,
213 _headers: PhantomData<fn() -> T>,
214}
215
216impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeaderService<S, T> {
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 f.debug_struct("SetForwardedHeaderService")
219 .field("inner", &self.inner)
220 .field("by_node", &self.by_node)
221 .field(
222 "_headers",
223 &format_args!("{}", std::any::type_name::<fn() -> T>()),
224 )
225 .finish()
226 }
227}
228
229impl<S: Clone, T> Clone for SetForwardedHeaderService<S, T> {
230 fn clone(&self) -> Self {
231 SetForwardedHeaderService {
232 inner: self.inner.clone(),
233 by_node: self.by_node.clone(),
234 _headers: PhantomData,
235 }
236 }
237}
238
239impl<S, T> SetForwardedHeaderService<S, T> {
240 pub fn forward_by(mut self, node_id: impl Into<NodeId>) -> Self {
244 self.by_node = node_id.into();
245 self
246 }
247
248 pub fn set_forward_by(&mut self, node_id: impl Into<NodeId>) -> &mut Self {
252 self.by_node = node_id.into();
253 self
254 }
255}
256
257impl<S, T> SetForwardedHeaderService<S, T> {
258 pub fn new(inner: S) -> Self {
260 Self {
261 inner,
262 by_node: Domain::from_static("rama").into(),
263 _headers: PhantomData,
264 }
265 }
266}
267
268impl<S> SetForwardedHeaderService<S> {
269 #[inline]
270 pub fn forwarded(inner: S) -> Self {
272 Self::new(inner)
273 }
274}
275
276impl<S> SetForwardedHeaderService<S, Via> {
277 #[inline]
278 pub fn via(inner: S) -> Self {
280 Self::new(inner)
281 }
282}
283
284impl<S> SetForwardedHeaderService<S, XForwardedFor> {
285 #[inline]
286 pub fn x_forwarded_for(inner: S) -> Self {
288 Self::new(inner)
289 }
290}
291
292impl<S> SetForwardedHeaderService<S, XForwardedHost> {
293 #[inline]
294 pub fn x_forwarded_host(inner: S) -> Self {
296 Self::new(inner)
297 }
298}
299
300impl<S> SetForwardedHeaderService<S, XForwardedProto> {
301 #[inline]
302 pub fn x_forwarded_proto(inner: S) -> Self {
304 Self::new(inner)
305 }
306}
307
308impl<S, H, State, Body> Service<State, Request<Body>> for SetForwardedHeaderService<S, H>
309where
310 S: Service<State, Request<Body>, Error: Into<BoxError>>,
311 H: ForwardHeader + Send + Sync + 'static,
312 Body: Send + 'static,
313 State: Clone + Send + Sync + 'static,
314{
315 type Response = S::Response;
316 type Error = BoxError;
317
318 async fn serve(
319 &self,
320 mut ctx: Context<State>,
321 mut req: Request<Body>,
322 ) -> Result<Self::Response, Self::Error> {
323 let forwarded: Option<rama_net::forwarded::Forwarded> = ctx.get().cloned();
324
325 let mut forwarded_element = ForwardedElement::forwarded_by(self.by_node.clone());
326
327 if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
328 forwarded_element.set_forwarded_for(peer_addr);
329 }
330 let request_ctx: &mut RequestContext =
331 ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;
332
333 forwarded_element.set_forwarded_host(request_ctx.authority.clone());
334
335 if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
336 forwarded_element.set_forwarded_proto(forwarded_proto);
337 }
338
339 let forwarded = match forwarded {
340 None => Some(rama_net::forwarded::Forwarded::new(forwarded_element)),
341 Some(mut forwarded) => {
342 forwarded.append(forwarded_element);
343 Some(forwarded)
344 }
345 };
346
347 if let Some(forwarded) = forwarded {
348 if let Some(header) = H::try_from_forwarded(forwarded.iter()) {
349 req.headers_mut().typed_insert(header);
350 }
351 }
352
353 self.inner.serve(ctx, req).await.map_err(Into::into)
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Response, StatusCode,
        headers::forwarded::{TrueClientIp, XRealIp},
        service::web::response::IntoResponse,
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use std::{convert::Infallible, net::IpAddr};

    /// Compile-time check that `T` is a service over `Request<()>` with unit state.
    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    /// Inner service that always answers `200 OK`.
    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    /// Request for `https://www.example.com` with an empty body.
    fn https_request() -> Request<()> {
        Request::builder()
            .uri("https://www.example.com")
            .body(())
            .unwrap()
    }

    /// Context carrying only socket info with peer `127.0.0.1:62345`.
    fn ctx_with_peer() -> Context<()> {
        let mut ctx = Context::default();
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        ctx
    }

    /// Context carrying an existing forwarded chain (`for=12.23.34.45`)
    /// followed by socket info with peer `127.0.0.1:62345`.
    fn ctx_with_chain_and_peer() -> Context<()> {
        let mut ctx = Context::default();
        ctx.insert(rama_net::forwarded::Forwarded::new(
            ForwardedElement::forwarded_for(IpAddr::from([12, 23, 34, 45])),
        ));
        ctx.insert(SocketInfo::new(None, "127.0.0.1:62345".parse().unwrap()));
        ctx
    }

    #[test]
    fn test_set_forwarded_service_is_service() {
        // Services built through the header-specific constructors.
        assert_is_service(SetForwardedHeaderService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeaderService::via(service_fn(dummy_service_fn)));
        assert_is_service(SetForwardedHeaderService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeaderService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(SetForwardedHeaderService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));

        // Header types without a dedicated constructor go through `new`.
        assert_is_service(SetForwardedHeaderService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));

        // Services produced by layers.
        assert_is_service(SetForwardedHeaderLayer::via().into_layer(service_fn(dummy_service_fn)));
        assert_is_service(
            SetForwardedHeaderLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            let value = request.headers().get("Forwarded").unwrap();
            assert_eq!(value, "by=rama;host=\"example.com:80\";proto=http");
            Ok(())
        }

        let req = Request::builder().uri("example.com").body(()).unwrap();
        let service = SetForwardedHeaderService::forwarded(service_fn(svc));
        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            let value = request.headers().get("Forwarded").unwrap();
            assert_eq!(
                value,
                "for=12.23.34.45,by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::forwarded(service_fn(svc));
        service
            .serve(ctx_with_chain_and_peer(), https_request())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_x_forwarded_for_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            let value = request.headers().get("X-Forwarded-For").unwrap();
            assert_eq!(value, "12.23.34.45, 127.0.0.1");
            Ok(())
        }

        let service = SetForwardedHeaderService::x_forwarded_for(service_fn(svc));
        service
            .serve(ctx_with_chain_and_peer(), https_request())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            let value = request.headers().get("Forwarded").unwrap();
            assert_eq!(
                value,
                "by=12.23.34.45;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::forwarded(service_fn(svc))
            .forward_by(IpAddr::from([12, 23, 34, 45]));
        service
            .serve(ctx_with_peer(), https_request())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded_fully_defined_with_chain() {
        async fn svc(request: Request<()>) -> Result<(), Infallible> {
            let value = request.headers().get("Forwarded").unwrap();
            assert_eq!(
                value,
                "by=rama;for=\"127.0.0.1:62345\";host=\"www.example.com:443\";proto=https",
            );
            Ok(())
        }

        let service = SetForwardedHeaderService::forwarded(service_fn(svc));
        service
            .serve(ctx_with_peer(), https_request())
            .await
            .unwrap();
    }
}