rama_http/layer/forwarded/
set_forwarded_multi.rs1use crate::Request;
2use crate::headers::HeaderMapExt;
3use crate::headers::forwarded::ForwardHeader;
4use rama_core::error::BoxError;
5use rama_core::{Context, Layer, Service};
6use rama_net::address::Domain;
7use rama_net::forwarded::{Forwarded, ForwardedElement, NodeId};
8use rama_net::http::RequestContext;
9use rama_net::stream::SocketInfo;
10use rama_utils::macros::all_the_tuples_no_last_special_case;
11use std::fmt;
12use std::marker::PhantomData;
13
14pub struct SetForwardedHeadersLayer<T = Forwarded> {
23 by_node: NodeId,
24 _headers: PhantomData<fn() -> T>,
25}
26
27impl<T: fmt::Debug> fmt::Debug for SetForwardedHeadersLayer<T> {
28 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
29 f.debug_struct("SetForwardedHeadersLayer")
30 .field("by_node", &self.by_node)
31 .field(
32 "_headers",
33 &format_args!("{}", std::any::type_name::<fn() -> T>()),
34 )
35 .finish()
36 }
37}
38
39impl<T: Clone> Clone for SetForwardedHeadersLayer<T> {
40 fn clone(&self) -> Self {
41 Self {
42 by_node: self.by_node.clone(),
43 _headers: PhantomData,
44 }
45 }
46}
47
48impl<T> Default for SetForwardedHeadersLayer<T> {
49 #[inline]
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl<T> SetForwardedHeadersLayer<T> {
56 pub fn new() -> Self {
58 Self {
59 by_node: Domain::from_static("rama").into(),
60 _headers: PhantomData,
61 }
62 }
63}
64
65impl<H, S> Layer<S> for SetForwardedHeadersLayer<H> {
66 type Service = SetForwardedHeadersService<S, H>;
67
68 fn layer(&self, inner: S) -> Self::Service {
69 Self::Service {
70 inner,
71 by_node: self.by_node.clone(),
72 _headers: PhantomData,
73 }
74 }
75
76 fn into_layer(self, inner: S) -> Self::Service {
77 Self::Service {
78 inner,
79 by_node: self.by_node,
80 _headers: PhantomData,
81 }
82 }
83}
84
85pub struct SetForwardedHeadersService<S, T = Forwarded> {
90 inner: S,
91 by_node: NodeId,
92 _headers: PhantomData<fn() -> T>,
93}
94
95impl<S: fmt::Debug, T> fmt::Debug for SetForwardedHeadersService<S, T> {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 f.debug_struct("SetForwardedHeadersService")
98 .field("inner", &self.inner)
99 .field("by_node", &self.by_node)
100 .field(
101 "_headers",
102 &format_args!("{}", std::any::type_name::<fn() -> T>()),
103 )
104 .finish()
105 }
106}
107
108impl<S: Clone, T> Clone for SetForwardedHeadersService<S, T> {
109 fn clone(&self) -> Self {
110 SetForwardedHeadersService {
111 inner: self.inner.clone(),
112 by_node: self.by_node.clone(),
113 _headers: PhantomData,
114 }
115 }
116}
117
118impl<S, T> SetForwardedHeadersService<S, T> {
119 pub fn new(inner: S) -> Self {
121 Self {
122 inner,
123 by_node: Domain::from_static("rama").into(),
124 _headers: PhantomData,
125 }
126 }
127}
128
/// Implements [`Service`] for `SetForwardedHeadersService<S, (T1, ..., Tn)>`, where every
/// `Ti` is a [`ForwardHeader`]. For each request it builds the forwarded chain for this
/// hop and inserts every header type that can be derived from that chain.
macro_rules! set_forwarded_service_for_tuple {
    ( $($ty:ident),* $(,)? ) => {
        // NOTE(review): `$ty` is only used as a type parameter/path below (no variable
        // bindings), so this allow looks vestigial; kept to avoid surprising lint changes.
        #[allow(non_snake_case)]
        impl<S, $($ty),* , State, Body> Service<State, Request<Body>> for SetForwardedHeadersService<S, ($($ty,)*)>
        where
            $( $ty: ForwardHeader + Send + Sync + 'static, )*
            S: Service<State, Request<Body>, Error: Into<BoxError>>,
            Body: Send + 'static,
            State: Clone + Send + Sync + 'static,
        {
            type Response = S::Response;
            type Error = BoxError;

            async fn serve(
                &self,
                mut ctx: Context<State>,
                mut req: Request<Body>,
            ) -> Result<Self::Response, Self::Error> {
                // Take a copy of any chain already in the context. The context entry itself
                // is not updated by this service.
                let existing: Option<Forwarded> = ctx.get().cloned();

                // The element describing this hop: `by` is always set. `for` is set only
                // when socket info is available.
                let mut element = ForwardedElement::forwarded_by(self.by_node.clone());

                if let Some(peer_addr) = ctx.get::<SocketInfo>().map(|socket| *socket.peer_addr()) {
                    element.set_forwarded_for(peer_addr);
                }

                // Derive (and cache in the context) the request context. Failure aborts the request.
                let request_ctx: &mut RequestContext =
                    ctx.get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())?;

                element.set_forwarded_host(request_ctx.authority.clone());

                // Protocols without a forwarded `proto` representation are silently skipped.
                if let Ok(forwarded_proto) = (&request_ctx.protocol).try_into() {
                    element.set_forwarded_proto(forwarded_proto);
                }

                // There is always a chain to emit: the existing one extended, or a new one.
                let forwarded = match existing {
                    Some(mut chain) => {
                        chain.append(element);
                        chain
                    }
                    None => Forwarded::new(element),
                };

                // Each header type decides for itself whether it can be derived from the chain.
                $(
                    if let Some(header) = $ty::try_from_forwarded(forwarded.iter()) {
                        req.headers_mut().typed_insert(header);
                    }
                )*

                self.inner.serve(ctx, req).await.map_err(Into::into)
            }
        }
    };
}
185all_the_tuples_no_last_special_case!(set_forwarded_service_for_tuple);
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Response, StatusCode,
        headers::forwarded::{TrueClientIp, XClientIp, XRealIp},
        service::web::response::IntoResponse,
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_http_headers::forwarded::XForwardedProto;
    use std::convert::Infallible;

    /// Compile-time check: only type-checks when `T` is a `Service<(), Request<()>>`.
    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    /// Trivial inner service that always answers `200 OK`.
    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_set_forwarded_service_is_service() {
        // One header type.
        let single =
            SetForwardedHeadersService::<_, (TrueClientIp,)>::new(service_fn(dummy_service_fn));
        assert_is_service(single);

        // Two header types.
        let pair = SetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(service_fn(
            dummy_service_fn,
        ));
        assert_is_service(pair);

        // Built through the layer instead of the service constructor.
        let layered = SetForwardedHeadersLayer::<(XRealIp, XForwardedProto)>::new()
            .into_layer(service_fn(dummy_service_fn));
        assert_is_service(layered);
    }

    #[tokio::test]
    async fn test_set_forwarded_service_forwarded() {
        /// Inner service asserting on the `Forwarded` header produced by the middleware.
        async fn inspect_request(request: Request<()>) -> Result<(), Infallible> {
            let header = request.headers().get("Forwarded").unwrap();
            assert_eq!(header, "by=rama;host=\"example.com:80\";proto=http");
            Ok(())
        }

        let request = Request::builder().uri("example.com").body(()).unwrap();
        let service =
            SetForwardedHeadersService::<_, (rama_http_headers::forwarded::Forwarded,)>::new(
                service_fn(inspect_request),
            );
        service.serve(Context::default(), request).await.unwrap();
    }
}