rama_http/layer/forwarded/
get_forwarded_multi.rs1use crate::Request;
2use crate::headers::forwarded::ForwardHeader;
3use rama_core::{Context, Layer, Service};
4use rama_http_headers::HeaderMapExt;
5use rama_net::forwarded::Forwarded;
6use rama_net::forwarded::ForwardedElement;
7use rama_utils::macros::all_the_tuples_no_last_special_case;
8use std::fmt;
9use std::marker::PhantomData;
10
11pub struct GetForwardedHeadersLayer<T = Forwarded> {
43 _headers: PhantomData<fn() -> T>,
44}
45
46impl<T: fmt::Debug> fmt::Debug for GetForwardedHeadersLayer<T> {
47 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
48 f.debug_struct("GetForwardedHeadersLayer")
49 .field(
50 "_headers",
51 &format_args!("{}", std::any::type_name::<fn() -> T>()),
52 )
53 .finish()
54 }
55}
56
57impl<T: Clone> Clone for GetForwardedHeadersLayer<T> {
58 fn clone(&self) -> Self {
59 Self {
60 _headers: PhantomData,
61 }
62 }
63}
64
65impl<T> Default for GetForwardedHeadersLayer<T> {
66 #[inline]
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl<T> GetForwardedHeadersLayer<T> {
73 pub const fn new() -> Self {
75 Self {
76 _headers: PhantomData,
77 }
78 }
79}
80
81impl<H, S> Layer<S> for GetForwardedHeadersLayer<H> {
82 type Service = GetForwardedHeadersService<S, H>;
83
84 fn layer(&self, inner: S) -> Self::Service {
85 Self::Service {
86 inner,
87 _headers: PhantomData,
88 }
89 }
90}
91
92pub struct GetForwardedHeadersService<S, T = Forwarded> {
96 inner: S,
97 _headers: PhantomData<fn() -> T>,
98}
99
100impl<S: fmt::Debug, T> fmt::Debug for GetForwardedHeadersService<S, T> {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 f.debug_struct("GetForwardedHeadersService")
103 .field("inner", &self.inner)
104 .field("_headers", &format_args!("{}", std::any::type_name::<T>()))
105 .finish()
106 }
107}
108
109impl<S: Clone, T> Clone for GetForwardedHeadersService<S, T> {
110 fn clone(&self) -> Self {
111 GetForwardedHeadersService {
112 inner: self.inner.clone(),
113 _headers: PhantomData,
114 }
115 }
116}
117
118impl<S, T> GetForwardedHeadersService<S, T> {
119 pub const fn new(inner: S) -> Self {
121 Self {
122 inner,
123 _headers: PhantomData,
124 }
125 }
126}
127
128macro_rules! get_forwarded_service_for_tuple {
129 ( $($ty:ident),* $(,)? ) => {
130 #[allow(non_snake_case)]
131 impl<$($ty,)* S, State, Body> Service<State, Request<Body>> for GetForwardedHeadersService<S, ($($ty,)*)>
132 where
133 $( $ty: ForwardHeader + Send + Sync + 'static, )*
134 S: Service<State, Request<Body>>,
135 Body: Send + 'static,
136 State: Clone + Send + Sync + 'static,
137 {
138 type Response = S::Response;
139 type Error = S::Error;
140
141 fn serve(
142 &self,
143 mut ctx: Context<State>,
144 req: Request<Body>,
145 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
146 let mut forwarded_elements: Vec<ForwardedElement> = Vec::with_capacity(1);
147
148 $(
149 if let Some($ty) = req.headers().typed_get::<$ty>() {
150 let mut iter = $ty.into_iter();
151 for element in forwarded_elements.iter_mut() {
152 let other = iter.next();
153 match other {
154 Some(other) => {
155 element.merge(other);
156 }
157 None => break,
158 }
159 }
160 for other in iter {
161 forwarded_elements.push(other);
162 }
163 }
164 )*
165
166 if !forwarded_elements.is_empty() {
167 match ctx.get_mut::<Forwarded>() {
168 Some(ref mut f) => {
169 f.extend(forwarded_elements);
170 }
171 None => {
172 let mut it = forwarded_elements.into_iter();
173 let mut forwarded = Forwarded::new(it.next().unwrap());
174 forwarded.extend(it);
175 ctx.insert(forwarded);
176 }
177 }
178 }
179
180 self.inner.serve(ctx, req)
181 }
182 }
183 }
184}
185
186all_the_tuples_no_last_special_case!(get_forwarded_service_for_tuple);
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Response, StatusCode,
        headers::forwarded::{ClientIp, TrueClientIp, XClientIp},
        service::web::response::IntoResponse,
    };
    use rama_core::{Layer, error::OpaqueError, service::service_fn};
    use rama_net::forwarded::ForwardedProtocol;
    use std::{convert::Infallible, net::IpAddr};

    /// Compiles only if the argument is a `Service<(), Request<()>>`.
    fn assert_is_service<T: Service<(), Request<()>>>(_: T) {}

    /// Inner service that always answers `200 OK`.
    async fn always_ok() -> Result<Response, OpaqueError> {
        Ok(StatusCode::OK.into_response())
    }

    #[test]
    fn test_get_forwarded_service_is_service() {
        // One header type, built directly from the service constructor.
        let single =
            GetForwardedHeadersService::<_, (TrueClientIp,)>::new(service_fn(always_ok));
        assert_is_service(single);

        // Two header types, built directly from the service constructor.
        let pair = GetForwardedHeadersService::<_, (TrueClientIp, XClientIp)>::new(
            service_fn(always_ok),
        );
        assert_is_service(pair);

        // Two header types, built through the layer.
        let layered = GetForwardedHeadersLayer::<(ClientIp, TrueClientIp)>::new()
            .into_layer(service_fn(always_ok));
        assert_is_service(layered);
    }

    #[tokio::test]
    async fn test_get_forwarded_headers() {
        let req = Request::builder()
            .header("Forwarded", "for=\"12.23.34.45:5000\";proto=http")
            .body(())
            .unwrap();

        let service = GetForwardedHeadersLayer::<(rama_http_headers::forwarded::Forwarded,)>::new()
            .into_layer(service_fn(async |ctx: Context<()>, _| {
                // The layer must have stored the parsed header in the context.
                let forwarded = ctx.get::<Forwarded>().unwrap();
                let expected_ip = IpAddr::from([12, 23, 34, 45]);
                assert_eq!(forwarded.client_ip(), Some(expected_ip));
                assert_eq!(forwarded.client_proto(), Some(ForwardedProtocol::HTTP));
                Ok::<_, Infallible>(())
            }));

        service.serve(Context::default(), req).await.unwrap();
    }
}