1use crate::forwarded::Forwarded;
2use crate::transport::{TransportContext, TransportProtocol, TryRefIntoTransportContext};
3use crate::{
4 Protocol,
5 address::{Authority, Host},
6};
7use rama_core::Context;
8use rama_core::error::OpaqueError;
9use rama_http_types::Method;
10use rama_http_types::{Request, Uri, Version, dep::http::request::Parts};
11
12#[cfg(feature = "tls")]
13use crate::tls::SecureTransport;
14
/// Extracts the SNI host from the ClientHello recorded on `t`, if any.
///
/// Scans the ClientHello extensions in order and returns the host of the first
/// `ServerName` extension that carries one. Returns `None` when there is no
/// ClientHello or no such extension.
#[cfg(feature = "tls")]
fn try_get_host_from_secure_transport(t: &SecureTransport) -> Option<Host> {
    use crate::tls::client::ClientHelloExtension;

    let client_hello = t.client_hello()?;
    client_hello.extensions().iter().find_map(|extension| {
        if let ClientHelloExtension::ServerName(server_name) = extension {
            server_name.clone()
        } else {
            None
        }
    })
}
26
27#[cfg(not(feature = "tls"))]
28#[derive(Debug, Clone)]
29#[non_exhaustive]
30struct SecureTransport;
31
32#[cfg(not(feature = "tls"))]
33fn try_get_host_from_secure_transport(_: &SecureTransport) -> Option<Host> {
34 None
35}
36
/// The HTTP-level context of a request: its HTTP version, application protocol
/// and authority.
///
/// Can be derived via the `TryFrom<(&Context<State>, &Request<Body>)>` and
/// `TryFrom<(&Context<State>, &Parts)>` implementations in this module, which
/// consult the request itself as well as the [`Context`] (TLS SNI and
/// [`Forwarded`] information).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestContext {
    /// HTTP version of the request.
    ///
    /// When derived via `TryFrom`, a client version found in [`Forwarded`]
    /// information takes precedence over the request's own version.
    pub http_version: Version,
    /// Application protocol of the request (e.g. HTTP or HTTPS).
    pub protocol: Protocol,
    /// Authority (host and port) of the request.
    pub authority: Authority,
}
54
55impl RequestContext {
56 pub fn authority_has_default_port(&self) -> bool {
58 self.protocol.default_port() == Some(self.authority.port())
59 }
60}
61
62impl<Body, State> TryFrom<(&Context<State>, &Request<Body>)> for RequestContext {
63 type Error = OpaqueError;
64
65 fn try_from((ctx, req): (&Context<State>, &Request<Body>)) -> Result<Self, Self::Error> {
66 let uri = req.uri();
67
68 let protocol = protocol_from_uri_or_context(ctx, uri, req.method());
69 tracing::trace!(
70 uri = %uri, "request context: detected protocol: {protocol} (scheme: {:?})",
71 uri.scheme()
72 );
73
74 let default_port = uri
75 .port_u16()
76 .unwrap_or_else(|| protocol.default_port().unwrap_or(80));
77 tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");
78
79 let authority = match ctx.get().and_then(try_get_host_from_secure_transport) {
80 Some(h) => {
81 tracing::trace!(uri = %uri, host = %h, "request context: detected host from SNI");
82 (h, default_port).into()
83 },
84 None => uri
85 .host()
86 .and_then(|h| Host::try_from(h).ok().map(|h| {
87 tracing::trace!(uri = %uri, host = %h, "request context: detected host from (abs) uri");
88 (h, default_port).into()
89 }))
90 .or_else(|| {
91 ctx.get::<Forwarded>().and_then(|f| {
92 f.client_host().map(|fauth| {
93 let (host, port) = fauth.clone().into_parts();
94 let port = port.unwrap_or(default_port);
95 tracing::trace!(uri = %uri, host = %host, "request context: detected host from forwarded info");
96 (host, port).into()
97 })
98 })
99 })
100 .or_else(|| {
101 req.headers()
102 .get(rama_http_types::header::HOST)
103 .and_then(|host| {
104 host.try_into() .or_else(|_| Host::try_from(host).map(|h| {
106 tracing::trace!(uri = %uri, host = %h, "request context: detected host from host header");
107 (h, default_port).into()
108 }))
109 .ok()
110 })
111 })
112 .ok_or_else(|| {
113 OpaqueError::from_display("RequestContext: no authourity found in http::Request")
114 })?
115 };
116
117 tracing::trace!(uri = %uri, "request context: detected authority: {authority}");
118
119 let http_version = ctx
120 .get::<Forwarded>()
121 .and_then(|f| {
122 f.client_version().map(|v| match v {
123 crate::forwarded::ForwardedVersion::HTTP_09 => Version::HTTP_09,
124 crate::forwarded::ForwardedVersion::HTTP_10 => Version::HTTP_10,
125 crate::forwarded::ForwardedVersion::HTTP_11 => Version::HTTP_11,
126 crate::forwarded::ForwardedVersion::HTTP_2 => Version::HTTP_2,
127 crate::forwarded::ForwardedVersion::HTTP_3 => Version::HTTP_3,
128 })
129 })
130 .unwrap_or_else(|| req.version());
131 tracing::trace!(uri = %uri, "request context: maybe detected http version: {http_version:?}");
132
133 Ok(RequestContext {
134 http_version,
135 protocol,
136 authority,
137 })
138 }
139}
140
141impl<State> TryFrom<(&Context<State>, &Parts)> for RequestContext {
142 type Error = OpaqueError;
143
144 fn try_from((ctx, parts): (&Context<State>, &Parts)) -> Result<Self, Self::Error> {
145 let uri = &parts.uri;
146
147 let protocol = protocol_from_uri_or_context(ctx, uri, &parts.method);
148 tracing::trace!(
149 uri = %uri, "request context: detected protocol: {protocol} (scheme: {:?})",
150 uri.scheme()
151 );
152
153 let default_port = uri
154 .port_u16()
155 .unwrap_or_else(|| protocol.default_port().unwrap_or(80));
156 tracing::trace!(uri = %uri, "request context: detected default port: {default_port}");
157
158 let authority = match ctx.get().and_then(try_get_host_from_secure_transport) {
159 Some(h) => {
160 tracing::trace!(uri = %uri, host = %h, "request context: detected host from SNI");
161 (h, default_port).into()
162 }
163 None => {
164 uri
165 .host()
166 .and_then(|h| Host::try_from(h).ok().map(|h| {
167 tracing::trace!(uri = %uri, host = %h, "request context: detected host from (abs) uri");
168 (h, default_port).into()
169 }))
170 .or_else(|| {
171 ctx.get::<Forwarded>().and_then(|f| {
172 f.client_host().map(|fauth| {
173 let (host, port) = fauth.clone().into_parts();
174 let port = port.unwrap_or(default_port);
175 tracing::trace!(uri = %uri, host = %host, "request context: detected host from forwarded info");
176 (host, port).into()
177 })
178 })
179 })
180 .or_else(|| {
181 parts
182 .headers
183 .get(rama_http_types::header::HOST)
184 .and_then(|host| {
185 host.try_into() .or_else(|_| Host::try_from(host).map(|h| {
187 tracing::trace!(uri = %uri, host = %h, "request context: detected host from host header");
188 (h, default_port).into()
189 }))
190 .ok()
191 })
192 })
193 .ok_or_else(|| {
194 OpaqueError::from_display(
195 "RequestContext: no authourity found in http::request::Parts",
196 )
197 })?
198 }
199 };
200
201 tracing::trace!(uri = %uri, "request context: detected authority: {authority}");
202
203 let http_version = ctx
204 .get::<Forwarded>()
205 .and_then(|f| {
206 f.client_version().map(|v| match v {
207 crate::forwarded::ForwardedVersion::HTTP_09 => Version::HTTP_09,
208 crate::forwarded::ForwardedVersion::HTTP_10 => Version::HTTP_10,
209 crate::forwarded::ForwardedVersion::HTTP_11 => Version::HTTP_11,
210 crate::forwarded::ForwardedVersion::HTTP_2 => Version::HTTP_2,
211 crate::forwarded::ForwardedVersion::HTTP_3 => Version::HTTP_3,
212 })
213 })
214 .unwrap_or(parts.version);
215 tracing::trace!(uri = %uri, "request context: maybe detected http version: {http_version:?}");
216
217 Ok(RequestContext {
218 http_version,
219 protocol,
220 authority,
221 })
222 }
223}
224
225#[allow(clippy::unnecessary_lazy_evaluations)]
226fn protocol_from_uri_or_context<State>(
227 ctx: &Context<State>,
228 uri: &Uri,
229 method: &Method,
230) -> Protocol {
231 Protocol::maybe_from_uri_scheme_str_and_method(uri.scheme(), Some(method)).or_else(|| ctx.get::<Forwarded>()
232 .and_then(|f| f.client_proto().map(|p| {
233 tracing::trace!(uri = %uri, "request context: detected protocol from forwarded client proto");
234 p.into()
235 })))
236 .unwrap_or_else(|| {
237 if method == Method::CONNECT {
238 tracing::trace!(uri = %uri, method = %method, "request context: CONNECT: defaulting protocol to HTTPS");
239 Protocol::HTTPS
240 } else {
241 tracing::trace!(uri = %uri, method = %method, "request context: defaulting protocol to HTTP");
242 Protocol::HTTP
243 }
244 })
245}
246
247impl From<RequestContext> for TransportContext {
248 fn from(value: RequestContext) -> Self {
249 Self {
250 protocol: if value.http_version == Version::HTTP_3 {
251 TransportProtocol::Udp
252 } else {
253 TransportProtocol::Tcp
254 },
255 app_protocol: Some(value.protocol),
256 http_version: Some(value.http_version),
257 authority: value.authority,
258 }
259 }
260}
261
262impl From<&RequestContext> for TransportContext {
263 fn from(value: &RequestContext) -> Self {
264 Self {
265 protocol: if value.http_version == Version::HTTP_3 {
266 TransportProtocol::Udp
267 } else {
268 TransportProtocol::Tcp
269 },
270 app_protocol: Some(value.protocol.clone()),
271 http_version: Some(value.http_version),
272 authority: value.authority.clone(),
273 }
274 }
275}
276
277impl<State, Body> TryRefIntoTransportContext<State> for rama_http_types::Request<Body> {
278 type Error = OpaqueError;
279
280 fn try_ref_into_transport_ctx(
281 &self,
282 ctx: &Context<State>,
283 ) -> Result<TransportContext, Self::Error> {
284 (ctx, self).try_into()
285 }
286}
287
288impl<State> TryRefIntoTransportContext<State> for rama_http_types::dep::http::request::Parts {
289 type Error = OpaqueError;
290
291 fn try_ref_into_transport_ctx(
292 &self,
293 ctx: &Context<State>,
294 ) -> Result<TransportContext, Self::Error> {
295 (ctx, self).try_into()
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::forwarded::{Forwarded, ForwardedElement, NodeId};
    use rama_http_types::header::FORWARDED;
    use rama_http_types::headers::HeaderMapExt;

    #[test]
    fn test_request_context_from_request() {
        let req = Request::builder()
            .uri("http://example.com:8080")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let ctx = Context::default();

        let req_ctx = RequestContext::try_from((&ctx, &req)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(req_ctx.authority.to_string(), "example.com:8080");
    }

    #[test]
    fn test_request_context_from_parts() {
        let req = Request::builder()
            .uri("http://example.com:8080")
            .version(Version::HTTP_11)
            .body(())
            .unwrap();

        let (parts, _) = req.into_parts();

        let ctx = Context::default();
        let req_ctx = RequestContext::try_from((&ctx, &parts)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(
            req_ctx.authority,
            Authority::try_from("example.com:8080").unwrap()
        );
    }

    #[test]
    fn test_request_context_authority() {
        let ctx = RequestContext {
            http_version: Version::HTTP_11,
            protocol: Protocol::HTTP,
            authority: "example.com:8080".try_into().unwrap(),
        };

        assert_eq!(ctx.authority.to_string(), "example.com:8080");
    }

    #[test]
    fn test_request_ctx_absolute_https_uri_uses_default_port() {
        let req = Request::builder()
            .uri("https://example.com/path")
            .version(Version::HTTP_2)
            .body(())
            .unwrap();

        let ctx = Context::default();
        let req_ctx = RequestContext::try_from((&ctx, &req)).unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_2);
        assert_eq!(req_ctx.protocol, Protocol::HTTPS);
        assert_eq!(req_ctx.authority.to_string(), "example.com:443");
        assert!(req_ctx.authority_has_default_port());
    }

    #[test]
    fn test_request_ctx_host_header_with_port() {
        // An explicit port in the Host header must win over the protocol default.
        let req = Request::builder()
            .uri("/foo")
            .version(Version::HTTP_11)
            .header("host", "example.com:8443")
            .body(())
            .unwrap();

        let ctx = Context::default();
        let req_ctx = RequestContext::try_from((&ctx, &req)).unwrap();

        assert_eq!(req_ctx.protocol, Protocol::HTTP);
        assert_eq!(req_ctx.authority.to_string(), "example.com:8443");
        assert!(!req_ctx.authority_has_default_port());
    }

    #[test]
    fn test_request_ctx_without_authority_is_error() {
        // Relative URI, no SNI, no forwarded info and no Host header.
        let req = Request::builder().uri("/foo").body(()).unwrap();

        let ctx = Context::default();
        assert!(RequestContext::try_from((&ctx, &req)).is_err());

        let (parts, _) = req.into_parts();
        assert!(RequestContext::try_from((&ctx, &parts)).is_err());
    }

    #[test]
    fn test_transport_ctx_from_request_ctx() {
        let req_ctx = RequestContext {
            http_version: Version::HTTP_3,
            protocol: Protocol::HTTPS,
            authority: "example.com:443".parse().unwrap(),
        };

        let transport_ctx = TransportContext::from(&req_ctx);
        assert!(matches!(transport_ctx.protocol, TransportProtocol::Udp));
        assert_eq!(transport_ctx.http_version, Some(Version::HTTP_3));
        assert_eq!(transport_ctx.app_protocol, Some(Protocol::HTTPS));
        assert_eq!(transport_ctx.authority, req_ctx.authority);

        let transport_ctx = TransportContext::from(RequestContext {
            http_version: Version::HTTP_11,
            ..req_ctx
        });
        assert!(matches!(transport_ctx.protocol, TransportProtocol::Tcp));
        assert_eq!(transport_ctx.http_version, Some(Version::HTTP_11));
    }

    #[test]
    fn forwarded_parsing() {
        for (forwarded_str_vec, expected) in [
            // single host with proto
            (
                vec!["host=192.0.2.60;proto=http;by=203.0.113.43"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
            // quoted IPv6 host with explicit port
            (
                vec!["host=\"[2001:db8:cafe::17]:4711\""],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "[2001:db8:cafe::17]:4711".parse().unwrap(),
                },
            ),
            // multiple elements in one header: the first one is used
            (
                vec!["host=192.0.2.60, host=127.0.0.1"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
            // multiple headers: the first one is used
            (
                vec!["host=192.0.2.60", "host=127.0.0.1"],
                RequestContext {
                    http_version: Version::HTTP_11,
                    protocol: Protocol::HTTP,
                    authority: "192.0.2.60:80".parse().unwrap(),
                },
            ),
        ] {
            let mut req_builder = Request::builder();
            for header in forwarded_str_vec.clone() {
                req_builder = req_builder.header(FORWARDED, header);
            }

            let req = req_builder.body(()).unwrap();
            let mut ctx = Context::default();

            let forwarded = req.headers().typed_get::<Forwarded>().unwrap();
            ctx.insert(forwarded);

            let req_ctx = ctx
                .get_or_try_insert_with_ctx::<RequestContext, _>(|ctx| (ctx, &req).try_into())
                .unwrap()
                .clone();

            assert_eq!(req_ctx, expected, "Failed for {:?}", forwarded_str_vec);
        }
    }

    #[test]
    fn test_request_ctx_https_request_behind_haproxy_plain() {
        let req = Request::builder()
            .uri("/en/reservation/roomdetails")
            .version(Version::HTTP_11)
            .header("host", "echo.ramaproxy.org")
            .header("user-agent", "curl/8.6.0")
            .header("accept", "*/*")
            .body(())
            .unwrap();

        let mut ctx = Context::default();
        ctx.insert(Forwarded::new(ForwardedElement::forwarded_for(
            NodeId::try_from("127.0.0.1:61234").unwrap(),
        )));

        let req_ctx: &mut RequestContext = ctx
            .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
            .unwrap();

        assert_eq!(req_ctx.http_version, Version::HTTP_11);
        assert_eq!(req_ctx.protocol, "http");
        assert_eq!(req_ctx.authority.to_string(), "echo.ramaproxy.org:80");
    }

    #[test]
    fn test_request_ctx_connect_req_no_scheme() {
        // These CONNECT targets have no scheme, so the protocol falls back to HTTPS
        // while the explicit port of the target is kept.
        let test_cases = [
            (80, Protocol::HTTPS),
            (443, Protocol::HTTPS),
            (8080, Protocol::HTTPS),
        ];
        for (port, expected_protocol) in test_cases {
            let req = Request::builder()
                .uri(format!("www.example.com:{port}"))
                .version(Version::HTTP_11)
                .method(Method::CONNECT)
                .header("host", "www.example.com")
                .header("user-agent", "test/42")
                .body(())
                .unwrap();

            let mut ctx = Context::default();
            let req_ctx: &mut RequestContext = ctx
                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
                .unwrap();

            assert_eq!(req_ctx.http_version, Version::HTTP_11);
            assert_eq!(req_ctx.protocol, expected_protocol);
            assert_eq!(
                req_ctx.authority.to_string(),
                format!("www.example.com:{}", port)
            );
        }
    }

    #[test]
    fn test_request_ctx_connect_req() {
        let test_cases = [
            ("http", Protocol::HTTPS),
            ("https", Protocol::HTTPS),
            ("ws", Protocol::WSS),
            ("wss", Protocol::WSS),
            ("ftp", Protocol::from_static("ftp")),
        ];
        for (scheme, expected_protocol) in test_cases {
            let req = Request::builder()
                .uri(format!("{scheme}://www.example.com"))
                .version(Version::HTTP_11)
                .method(Method::CONNECT)
                .header("host", "www.example.com")
                .header("user-agent", "test/42")
                .body(())
                .unwrap();

            let mut ctx = Context::default();
            let req_ctx: &mut RequestContext = ctx
                .get_or_try_insert_with_ctx(|ctx| (ctx, &req).try_into())
                .unwrap();

            assert_eq!(req_ctx.http_version, Version::HTTP_11);
            assert_eq!(req_ctx.protocol, expected_protocol);
            assert_eq!(
                req_ctx.authority.to_string(),
                format!(
                    "www.example.com:{}",
                    expected_protocol.default_port().unwrap_or(80)
                )
            );
        }
    }
}