1use crate::forwarded::Forwarded;
2use crate::transport::{TransportContext, TransportProtocol, TryRefIntoTransportContext};
3use crate::{
4 Protocol,
5 address::{Authority, Host},
6};
7use rama_core::Context;
8use rama_core::error::OpaqueError;
9use rama_http_types::Method;
10use rama_http_types::{Request, Uri, Version, dep::http::request::Parts};
11
12#[cfg(feature = "tls")]
13use crate::tls::SecureTransport;
14
/// Extracts the SNI server name from the TLS `ClientHello` recorded in the
/// [`SecureTransport`], if any.
///
/// Returns `None` when no `ClientHello` is available or when none of its
/// `ServerName` extensions carries a host; the first host found wins.
#[cfg(feature = "tls")]
fn try_get_host_from_secure_transport(t: &SecureTransport) -> Option<Host> {
    use crate::tls::client::ClientHelloExtension;

    let client_hello = t.client_hello()?;
    client_hello.extensions().iter().find_map(|extension| {
        if let ClientHelloExtension::ServerName(server_name) = extension {
            server_name.clone()
        } else {
            None
        }
    })
}
26
/// Stand-in for the TLS `SecureTransport` type when the `tls` feature is
/// disabled.
///
/// It only exists so that `ctx.get::<SecureTransport>()` lookups keep
/// type-checking. Being private and never constructed here, such lookups
/// find nothing.
#[cfg(not(feature = "tls"))]
#[non_exhaustive]
#[derive(Clone, Debug)]
struct SecureTransport;
31
32#[cfg(not(feature = "tls"))]
33fn try_get_host_from_secure_transport(_: &SecureTransport) -> Option<Host> {
34 None
35}
36
37#[derive(Debug, Clone, PartialEq, Eq)]
38pub struct RequestContext {
40 pub http_version: Version,
42 pub protocol: Protocol,
44 pub authority: Authority,
53}
54
55impl RequestContext {
56 pub fn authority_has_default_port(&self) -> bool {
58 self.protocol.default_port() == Some(self.authority.port())
59 }
60}
61
62impl<Body, State> TryFrom<(&Context<State>, &Request<Body>)> for RequestContext {
63 type Error = OpaqueError;
64
65 fn try_from((ctx, req): (&Context<State>, &Request<Body>)) -> Result<Self, Self::Error> {
66 let uri = req.uri();
67
68 let protocol = protocol_from_uri_or_context(ctx, uri, req.method());
69 tracing::trace!(
70 uri = %uri, "request context: detected protocol: {protocol} (scheme: {:?})",
71 uri.scheme()
72 );
73
74 let default_port = uri
75 .port_u16()
76 .unwrap_or_else(|| protocol.default_port().unwrap_or(80));
77 tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");
78
79 let authority = match ctx.get().and_then(try_get_host_from_secure_transport) {
80 Some(h) => {
81 tracing::trace!(uri = %uri, host = %h, "request context: detected host from SNI");
82 (h, default_port).into()
83 },
84 None => uri
85 .host()
86 .and_then(|h| Host::try_from(h).ok().map(|h| {
87 tracing::trace!(uri = %uri, host = %h, "request context: detected host from (abs) uri");
88 (h, default_port).into()
89 }))
90 .or_else(|| {
91 ctx.get::<Forwarded>().and_then(|f| {
92 f.client_host().map(|fauth| {
93 let (host, port) = fauth.clone().into_parts();
94 let port = port.unwrap_or(default_port);
95 tracing::trace!(uri = %uri, host = %host, "request context: detected host from forwarded info");
96 (host, port).into()
97 })
98 })
99 })
100 .or_else(|| {
101 req.headers()
102 .get(rama_http_types::header::HOST)
103 .and_then(|host| {
104 host.try_into() .or_else(|_| Host::try_from(host).map(|h| {
106 tracing::trace!(uri = %uri, host = %h, "request context: detected host from host header");
107 (h, default_port).into()
108 }))
109 .ok()
110 })
111 })
112 .ok_or_else(|| {
113 OpaqueError::from_display("RequestContext: no authourity found in http::Request")
114 })?
115 };
116
117 tracing::trace!(uri = %uri, "request context: detected authority: {authority}");
118
119 let http_version = ctx
120 .get::<Forwarded>()
121 .and_then(|f| {
122 f.client_version().map(|v| match v {
123 crate::forwarded::ForwardedVersion::HTTP_09 => Version::HTTP_09,
124 crate::forwarded::ForwardedVersion::HTTP_10 => Version::HTTP_10,
125 crate::forwarded::ForwardedVersion::HTTP_11 => Version::HTTP_11,
126 crate::forwarded::ForwardedVersion::HTTP_2 => Version::HTTP_2,
127 crate::forwarded::ForwardedVersion::HTTP_3 => Version::HTTP_3,
128 })
129 })
130 .unwrap_or_else(|| req.version());
131 tracing::trace!(uri = %uri, "request context: maybe detected http version: {http_version:?}");
132
133 Ok(RequestContext {
134 http_version,
135 protocol,
136 authority,
137 })
138 }
139}
140
141impl<State> TryFrom<(&Context<State>, &Parts)> for RequestContext {
142 type Error = OpaqueError;
143
144 fn try_from((ctx, parts): (&Context<State>, &Parts)) -> Result<Self, Self::Error> {
145 let uri = &parts.uri;
146
147 let protocol = protocol_from_uri_or_context(ctx, uri, &parts.method);
148 tracing::trace!(
149 uri = %uri, "request context: detected protocol: {protocol} (scheme: {:?})",
150 uri.scheme()
151 );
152
153 let default_port = uri
154 .port_u16()
155 .unwrap_or_else(|| protocol.default_port().unwrap_or(80));
156 tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");
157
158 let authority = match ctx.get().and_then(try_get_host_from_secure_transport) {
159 Some(h) => {
160 tracing::trace!(uri = %uri, host = %h, "request context: detected host from SNI");
161 (h, default_port).into()
162 }
163 None => {
164 uri
165 .host()
166 .and_then(|h| Host::try_from(h).ok().map(|h| {
167 tracing::trace!(uri = %uri, host = %h, "request context: detected host from (abs) uri");
168 (h, default_port).into()
169 }))
170 .or_else(|| {
171 ctx.get::<Forwarded>().and_then(|f| {
172 f.client_host().map(|fauth| {
173 let (host, port) = fauth.clone().into_parts();
174 let port = port.unwrap_or(default_port);
175 tracing::trace!(uri = %uri, host = %host, "request context: detected host from forwarded info");
176 (host, port).into()
177 })
178 })
179 })
180 .or_else(|| {
181 parts
182 .headers
183 .get(rama_http_types::header::HOST)
184 .and_then(|host| {
185 host.try_into() .or_else(|_| Host::try_from(host).map(|h| {
187 tracing::trace!(uri = %uri, host = %h, "request context: detected host from host header");
188 (h, default_port).into()
189 }))
190 .ok()
191 })
192 })
193 .ok_or_else(|| {
194 OpaqueError::from_display(
195 "RequestContext: no authourity found in http::request::Parts",
196 )
197 })?
198 }
199 };
200
201 tracing::trace!(uri = %uri, "request context: detected authority: {authority}");
202
203 let http_version = ctx
204 .get::<Forwarded>()
205 .and_then(|f| {
206 f.client_version().map(|v| match v {
207 crate::forwarded::ForwardedVersion::HTTP_09 => Version::HTTP_09,
208 crate::forwarded::ForwardedVersion::HTTP_10 => Version::HTTP_10,
209 crate::forwarded::ForwardedVersion::HTTP_11 => Version::HTTP_11,
210 crate::forwarded::ForwardedVersion::HTTP_2 => Version::HTTP_2,
211 crate::forwarded::ForwardedVersion::HTTP_3 => Version::HTTP_3,
212 })
213 })
214 .unwrap_or(parts.version);
215 tracing::trace!(uri = %uri, "request context: maybe detected http version: {http_version:?}");
216
217 Ok(RequestContext {
218 http_version,
219 protocol,
220 authority,
221 })
222 }
223}
224
225#[allow(clippy::unnecessary_lazy_evaluations)]
226fn protocol_from_uri_or_context<State>(
227 ctx: &Context<State>,
228 uri: &Uri,
229 method: &Method,
230) -> Protocol {
231 Protocol::maybe_from_uri_scheme_str_and_method(uri.scheme(), Some(method)).or_else(|| ctx.get::<Forwarded>()
232 .and_then(|f| f.client_proto().map(|p| {
233 tracing::trace!(uri = %uri, "request context: detected protocol from forwarded client proto");
234 p.into()
235 })))
236 .unwrap_or_else(|| {
237 if method == Method::CONNECT {
238 tracing::trace!(uri = %uri, method = %method, "request context: CONNECT: defaulting protocol to HTTPS");
239 Protocol::HTTPS
240 } else {
241 tracing::trace!(uri = %uri, method = %method, "request context: defaulting protocol to HTTP");
242 Protocol::HTTP
243 }
244 })
245}
246
247impl From<RequestContext> for TransportContext {
248 fn from(value: RequestContext) -> Self {
249 Self {
250 protocol: if value.http_version == Version::HTTP_3 {
251 TransportProtocol::Udp
252 } else {
253 TransportProtocol::Tcp
254 },
255 app_protocol: Some(value.protocol),
256 http_version: Some(value.http_version),
257 authority: value.authority,
258 }
259 }
260}
261
262impl From<&RequestContext> for TransportContext {
263 fn from(value: &RequestContext) -> Self {
264 Self {
265 protocol: if value.http_version == Version::HTTP_3 {
266 TransportProtocol::Udp
267 } else {
268 TransportProtocol::Tcp
269 },
270 app_protocol: Some(value.protocol.clone()),
271 http_version: Some(value.http_version),
272 authority: value.authority.clone(),
273 }
274 }
275}
276
277impl<State, Body> TryRefIntoTransportContext<State> for rama_http_types::Request<Body> {
278 type Error = OpaqueError;
279
280 fn try_ref_into_transport_ctx(
281 &self,
282 ctx: &Context<State>,
283 ) -> Result<TransportContext, Self::Error> {
284 (ctx, self).try_into()
285 }
286}
287
288impl<State> TryRefIntoTransportContext<State> for rama_http_types::dep::http::request::Parts {
289 type Error = OpaqueError;
290
291 fn try_ref_into_transport_ctx(
292 &self,
293 ctx: &Context<State>,
294 ) -> Result<TransportContext, Self::Error> {
295 (ctx, self).try_into()
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::forwarded::{Forwarded, ForwardedElement, NodeId};
    use rama_http_types::header::FORWARDED;

    #[test]
    fn test_request_context_from_request() {
        let req = Request::builder()
            .uri("http://example.com:8080")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let ctx = Context::default();

        let req_ctx = RequestContext::try_from((&ctx, &req)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(req_ctx.authority.to_string(), "example.com:8080");
    }

    #[test]
    fn test_request_context_from_parts() {
        let req = Request::builder()
            .uri("http://example.com:8080")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let (parts, _) = req.into_parts();

        let ctx = Context::default();
        let req_ctx = RequestContext::try_from((&ctx, &parts)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(
            req_ctx.authority,
            Authority::try_from("example.com:8080").unwrap()
        );
    }

    #[test]
    fn test_request_context_authority() {
        let ctx = RequestContext {
            http_version: Version::HTTP_11,
            protocol: Protocol::HTTP,
            authority: "example.com:8080".try_into().unwrap(),
        };

        assert_eq!(ctx.authority.to_string(), "example.com:8080");
    }

    #[test]
    fn test_request_context_authority_from_host_header_with_port() {
        // an origin-form URI carries no host, so the `Host` header is used,
        // including its explicit port
        let req = Request::builder()
            .uri("/foo")
            .version(Version::HTTP_11)
            .header("host", "example.com:8443")
            .body(())
            .unwrap();

        let ctx = Context::default();
        let req_ctx = RequestContext::try_from((&ctx, &req)).unwrap();

        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(req_ctx.authority.to_string(), "example.com:8443");
    }

    #[test]
    fn test_request_context_without_authority_is_error() {
        let ctx = Context::default();

        let req = Request::builder().uri("/foo").body(()).unwrap();
        assert!(RequestContext::try_from((&ctx, &req)).is_err());

        let (parts, _) = req.into_parts();
        assert!(RequestContext::try_from((&ctx, &parts)).is_err());
    }

    #[test]
    fn test_request_context_authority_has_default_port() {
        let mut req_ctx = RequestContext {
            http_version: Version::HTTP_11,
            protocol: Protocol::HTTPS,
            authority: "example.com:443".parse().unwrap(),
        };
        assert!(req_ctx.authority_has_default_port());

        req_ctx.authority = "example.com:8443".parse().unwrap();
        assert!(!req_ctx.authority_has_default_port());
    }

    #[test]
    fn test_transport_context_from_request_context() {
        let req_ctx = RequestContext {
            http_version: Version::HTTP_11,
            protocol: Protocol::HTTP,
            authority: "example.com:8080".parse().unwrap(),
        };

        let transport_ctx: TransportContext = (&req_ctx).into();
        assert!(matches!(transport_ctx.protocol, TransportProtocol::Tcp));
        assert_eq!(transport_ctx.app_protocol, Some(Protocol::HTTP));
        assert_eq!(transport_ctx.http_version, Some(Version::HTTP_11));
        assert_eq!(transport_ctx.authority, req_ctx.authority);

        // HTTP/3 is the only version mapped onto a UDP transport
        let h3_req_ctx = RequestContext {
            http_version: Version::HTTP_3,
            ..req_ctx
        };
        let transport_ctx: TransportContext = h3_req_ctx.into();
        assert!(matches!(transport_ctx.protocol, TransportProtocol::Udp));
        assert_eq!(transport_ctx.http_version, Some(Version::HTTP_3));
    }

    #[test]
    fn forwarded_parsing() {
        for (forwarded_str_vec, expected) in [
            (
                // host, proto and by combined in a single forwarded element
                vec!["host=192.0.2.60;proto=http;by=203.0.113.43"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
            (
                // quoted IPv6 host with an explicit port
                vec!["host=\"[2001:db8:cafe::17]:4711\""],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "[2001:db8:cafe::17]:4711".parse().unwrap(),
                },
            ),
            (
                // multiple elements in one header: the first host is used
                vec!["host=192.0.2.60, host=127.0.0.1"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
            (
                // multiple headers: only the first one is parsed below
                vec!["host=192.0.2.60", "host=127.0.0.1"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
        ] {
            let mut req_builder = Request::builder();
            for header in &forwarded_str_vec {
                req_builder = req_builder.header(FORWARDED, *header);
            }

            let req = req_builder.body(()).unwrap();
            let mut ctx = Context::default();

            let forwarded: Forwarded = req.headers().get(FORWARDED).unwrap().try_into().unwrap();
            ctx.insert(forwarded);

            let req_ctx = ctx
                .get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
                .unwrap()
                .clone();

            assert_eq!(req_ctx, expected, "Failed for {:?}", forwarded_str_vec);
        }
    }

    #[test]
    fn test_request_ctx_https_request_behind_haproxy_plain() {
        let req = Request::builder()
            .uri("/en/reservation/roomdetails")
            .version(Version::HTTP_11)
            .header("host", "echo.ramaproxy.org")
            .header("user-agent", "curl/8.6.0")
            .header("accept", "*/*")
            .body(())
            .unwrap();

        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            NodeId::try_from("127.0.0.1:61234").unwrap(),
        )));

        let req_ctx: &mut RequestContext = ctx
            .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
            .unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, "http");
        assert_eq!(req_ctx.authority.to_string(), "echo.ramaproxy.org:80");
    }

    #[test]
    fn test_request_ctx_connect_req_no_scheme() {
        let test_cases = [
            (80, Protocol::HTTPS),
            (443, Protocol::HTTPS),
            (8080, Protocol::HTTPS),
        ];
        for (port, expected_protocol) in test_cases {
            let req = Request::builder()
                .uri(format!("www.example.com:{port}"))
                .version(Version::HTTP_11)
                .method(Method::CONNECT)
                .header("host", "www.example.com")
                .header("user-agent", "test/42")
                .body(())
                .unwrap();

            let mut ctx = Context::default();
            let req_ctx: &mut RequestContext = ctx
                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
                .unwrap();

            assert_eq!(req_ctx.http_version, Version::HTTP_11);
            assert_eq!(req_ctx.protocol, expected_protocol);
            assert_eq!(
                req_ctx.authority.to_string(),
                format!("www.example.com:{}", port)
            );
        }
    }

    #[test]
    fn test_request_ctx_connect_req() {
        let test_cases = [
            ("http", Protocol::HTTPS),
            ("https", Protocol::HTTPS),
            ("ws", Protocol::WSS),
            ("wss", Protocol::WSS),
            ("ftp", Protocol::from_static("ftp")),
        ];
        for (scheme, expected_protocol) in test_cases {
            let req = Request::builder()
                .uri(format!("{scheme}://www.example.com"))
                .version(Version::HTTP_11)
                .method(Method::CONNECT)
                .header("host", "www.example.com")
                .header("user-agent", "test/42")
                .body(())
                .unwrap();

            let mut ctx = Context::default();
            let req_ctx: &mut RequestContext = ctx
                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
                .unwrap();

            assert_eq!(req_ctx.http_version, Version::HTTP_11);
            assert_eq!(req_ctx.protocol, expected_protocol);
            assert_eq!(
                req_ctx.authority.to_string(),
                format!(
                    "www.example.com:{}",
                    expected_protocol.default_port().unwrap_or(80)
                )
            );
        }
    }
}