rama_http/layer/forwarded/
get_forwarded.rs1use crate::Request;
2use crate::headers::forwarded::{
3 ForwardHeader, Via, XForwardedFor, XForwardedHost, XForwardedProto,
4};
5use rama_core::{Context, Layer, Service};
6use rama_http_headers::HeaderMapExt;
7use rama_http_headers::forwarded::Forwarded;
8use rama_net::forwarded::ForwardedElement;
9use std::fmt;
10use std::marker::PhantomData;
11
12pub struct GetForwardedHeaderLayer<T = rama_http_headers::forwarded::Forwarded> {
74 _headers: PhantomData<fn() -> T>,
75}
76
77impl<T: fmt::Debug> fmt::Debug for GetForwardedHeaderLayer<T> {
78 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79 f.debug_struct("GetForwardedHeaderLayer")
80 .field(
81 "_headers",
82 &format_args!("{}", std::any::type_name::<fn() -> T>()),
83 )
84 .finish()
85 }
86}
87
88impl<T: Clone> Clone for GetForwardedHeaderLayer<T> {
89 fn clone(&self) -> Self {
90 Self {
91 _headers: PhantomData,
92 }
93 }
94}
95
96impl Default for GetForwardedHeaderLayer {
97 fn default() -> Self {
98 Self::forwarded()
99 }
100}
101
102impl<T> GetForwardedHeaderLayer<T> {
103 pub const fn new() -> Self {
105 Self {
106 _headers: PhantomData,
107 }
108 }
109}
110
111impl GetForwardedHeaderLayer {
112 #[inline]
113 pub fn forwarded() -> Self {
115 Self::new()
116 }
117}
118
119impl GetForwardedHeaderLayer<Via> {
120 #[inline]
121 pub fn via() -> Self {
123 Self::new()
124 }
125}
126
127impl GetForwardedHeaderLayer<XForwardedFor> {
128 #[inline]
129 pub fn x_forwarded_for() -> Self {
131 Self::new()
132 }
133}
134
135impl GetForwardedHeaderLayer<XForwardedHost> {
136 #[inline]
137 pub fn x_forwarded_host() -> Self {
139 Self::new()
140 }
141}
142
143impl GetForwardedHeaderLayer<XForwardedProto> {
144 #[inline]
145 pub fn x_forwarded_proto() -> Self {
147 Self::new()
148 }
149}
150
151impl<H, S> Layer<S> for GetForwardedHeaderLayer<H> {
152 type Service = GetForwardedHeaderService<S, H>;
153
154 fn layer(&self, inner: S) -> Self::Service {
155 Self::Service {
156 inner,
157 _headers: PhantomData,
158 }
159 }
160}
161
162pub struct GetForwardedHeaderService<S, T = Forwarded> {
166 inner: S,
167 _headers: PhantomData<fn() -> T>,
168}
169
170impl<S: fmt::Debug, T> fmt::Debug for GetForwardedHeaderService<S, T> {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 f.debug_struct("GetForwardedHeaderService")
173 .field("inner", &self.inner)
174 .field("_headers", &format_args!("{}", std::any::type_name::<T>()))
175 .finish()
176 }
177}
178
179impl<S: Clone, T> Clone for GetForwardedHeaderService<S, T> {
180 fn clone(&self) -> Self {
181 GetForwardedHeaderService {
182 inner: self.inner.clone(),
183 _headers: PhantomData,
184 }
185 }
186}
187
188impl<S, T> GetForwardedHeaderService<S, T> {
189 pub const fn new(inner: S) -> Self {
191 Self {
192 inner,
193 _headers: PhantomData,
194 }
195 }
196}
197
198impl<S> GetForwardedHeaderService<S> {
199 #[inline]
200 pub fn forwarded(inner: S) -> Self {
202 Self::new(inner)
203 }
204}
205
206impl<S> GetForwardedHeaderService<S, Via> {
207 #[inline]
208 pub fn via(inner: S) -> Self {
210 Self::new(inner)
211 }
212}
213
214impl<S> GetForwardedHeaderService<S, XForwardedFor> {
215 #[inline]
216 pub fn x_forwarded_for(inner: S) -> Self {
218 Self::new(inner)
219 }
220}
221
222impl<S> GetForwardedHeaderService<S, XForwardedHost> {
223 #[inline]
224 pub fn x_forwarded_host(inner: S) -> Self {
226 Self::new(inner)
227 }
228}
229
230impl<S> GetForwardedHeaderService<S, XForwardedProto> {
231 #[inline]
232 pub fn x_forwarded_proto(inner: S) -> Self {
234 Self::new(inner)
235 }
236}
237
238impl<H, S, State, Body> Service<State, Request<Body>> for GetForwardedHeaderService<S, H>
239where
240 H: ForwardHeader + Send + Sync + 'static,
241 S: Service<State, Request<Body>>,
242 Body: Send + 'static,
243 State: Clone + Send + Sync + 'static,
244{
245 type Response = S::Response;
246 type Error = S::Error;
247
248 fn serve(
249 &self,
250 mut ctx: Context<State>,
251 req: Request<Body>,
252 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
253 let mut forwarded_elements: Vec<ForwardedElement> = Vec::with_capacity(1);
254
255 if let Some(header) = req.headers().typed_get::<H>() {
256 forwarded_elements.extend(header);
257 }
258
259 if !forwarded_elements.is_empty() {
260 match ctx.get_mut::<Forwarded>() {
261 Some(ref mut f) => {
262 f.extend(forwarded_elements);
263 }
264 None => {
265 let mut it = forwarded_elements.into_iter();
266 let mut forwarded = rama_net::forwarded::Forwarded::new(it.next().unwrap());
267 forwarded.extend(it);
268 ctx.insert(forwarded);
269 }
270 }
271 }
272
273 self.inner.serve(ctx, req)
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Response, StatusCode, service::web::response::IntoResponse};
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_http_headers::forwarded::{TrueClientIp, XRealIp};
    use rama_net::forwarded::{ForwardedProtocol, ForwardedVersion};
    use std::{convert::Infallible, net::IpAddr};

    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    async fn dummy_service_fn() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_get_forwarded_service_is_service() {
        assert_is_service(GetForwardedHeaderService::forwarded(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::via(service_fn(dummy_service_fn)));
        assert_is_service(GetForwardedHeaderService::x_forwarded_for(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::x_forwarded_proto(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::x_forwarded_host(service_fn(
            dummy_service_fn,
        )));
        assert_is_service(GetForwardedHeaderService::<_, TrueClientIp>::new(
            service_fn(dummy_service_fn),
        ));
        assert_is_service(
            GetForwardedHeaderLayer::forwarded().into_layer(service_fn(dummy_service_fn)),
        );
        assert_is_service(GetForwardedHeaderLayer::via().into_layer(service_fn(dummy_service_fn)));
        assert_is_service(
            GetForwardedHeaderLayer::<XRealIp>::new().into_layer(service_fn(dummy_service_fn)),
        );
    }

    #[tokio::test]
    async fn test_get_forwarded_header_forwarded() {
        let service = GetForwardedHeaderLayer::forwarded().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert_eq!(forwarded.client_proto(), Some(ForwardedProtocol::HTTP));
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_via() {
        let service =
            GetForwardedHeaderLayer::via().into_layer(service_fn(async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert!(forwarded.client_ip().is_none());
                assert_eq!(
                    forwarded.iter().next().unwrap().ref_forwarded_by(),
                    Some(&(IpAddr::from([12, 23, 34, 45]), 5000).into())
                );
                assert!(forwarded.client_proto().is_none());
                assert_eq!(forwarded.client_version(), Some(ForwardedVersion::HTTP_11));
                Ok::<_, Infallible>(())
            }));

        let req = Request::builder()
            .header("Via", "1.1 12.23.34.45:5000")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_x_forwarded_for() {
        let service = GetForwardedHeaderLayer::x_forwarded_for().into_layer(service_fn(
            async |ctx: Context<()>, _| {
                let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                assert!(forwarded.client_proto().is_none());
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder()
            .header("X-Forwarded-For", "12.23.34.45, 127.0.0.1")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_missing_header_leaves_context_untouched() {
        let service = GetForwardedHeaderLayer::forwarded().into_layer(service_fn(
            async |ctx: Context<()>, _req: Request<()>| {
                assert!(ctx.get::<rama_net::forwarded::Forwarded>().is_none());
                Ok::<_, Infallible>(())
            },
        ));

        let req = Request::builder().body(()).unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }

    #[tokio::test]
    async fn test_get_forwarded_header_extends_existing_forwarded() {
        // The outer layer inserts a `Forwarded` value first, so the inner
        // layer must append its elements to it instead of replacing it.
        let service = GetForwardedHeaderLayer::forwarded().into_layer(
            GetForwardedHeaderLayer::x_forwarded_for().into_layer(service_fn(
                async |ctx: Context<()>, _req: Request<()>| {
                    let forwarded = ctx.get::<rama_net::forwarded::Forwarded>().unwrap();
                    assert_eq!(forwarded.client_ip(), Some(IpAddr::from([12, 23, 34, 45])));
                    assert_eq!(forwarded.client_proto(), Some(ForwardedProtocol::HTTP));
                    assert_eq!(forwarded.iter().count(), 3);
                    Ok::<_, Infallible>(())
                },
            )),
        );

        let req = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .header("X-Forwarded-For", "1.2.3.4, 127.0.0.1")
            .body(())
            .unwrap();

        service.serve(Context::default(), req).await.unwrap();
    }
}