reinhardt_http/request.rs
1mod body;
2mod methods;
3mod params;
4
5use crate::extensions::Extensions;
6use crate::path_params::PathParams;
7use bytes::Bytes;
8use hyper::{HeaderMap, Method, Uri, Version};
9#[cfg(feature = "parsers")]
10use reinhardt_core::parsers::parser::{ParsedData, Parser};
11use std::collections::HashMap;
12use std::collections::HashSet;
13use std::net::{IpAddr, SocketAddr};
14use std::sync::atomic::AtomicBool;
15use std::sync::{Arc, Mutex};
16
/// Configuration for trusted proxy IPs.
///
/// Only proxy headers (X-Forwarded-For, X-Real-IP, X-Forwarded-Proto) from
/// these IP addresses will be trusted. By default, no proxies are trusted
/// and the actual connection information is used.
///
/// Addresses are matched individually (there is no CIDR/range support).
/// IPv4-mapped IPv6 addresses (e.g. `::ffff:10.0.0.1`) are normalized to
/// their IPv4 form both when configuring and when checking. A proxy configured
/// as `10.0.0.1` is therefore still recognized when a dual-stack IPv6 listener
/// reports the peer in mapped form.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TrustedProxies {
    /// Set of trusted proxy IP addresses, stored in canonical form
    /// (IPv4-mapped IPv6 addresses are converted to plain IPv4).
    /// Only requests originating from these IPs will have their proxy headers honored.
    trusted_ips: HashSet<IpAddr>,
}

impl TrustedProxies {
    /// Create with no trusted proxies (default, most secure).
    pub fn none() -> Self {
        Self {
            trusted_ips: HashSet::new(),
        }
    }

    /// Create with a set of trusted proxy IPs.
    ///
    /// Each address is canonicalized with [`IpAddr::to_canonical`], so
    /// `::ffff:10.0.0.1` and `10.0.0.1` configure the same proxy.
    pub fn new(ips: impl IntoIterator<Item = IpAddr>) -> Self {
        Self {
            trusted_ips: ips.into_iter().map(|ip| ip.to_canonical()).collect(),
        }
    }

    /// Check if the given address is a trusted proxy.
    ///
    /// The address is canonicalized before the lookup. An IPv4 peer reported
    /// as an IPv4-mapped IPv6 address therefore matches its configured IPv4
    /// entry.
    pub fn is_trusted(&self, addr: &IpAddr) -> bool {
        self.trusted_ips.contains(&addr.to_canonical())
    }

    /// Check if any proxies are configured.
    pub fn has_trusted_proxies(&self) -> bool {
        !self.trusted_ips.is_empty()
    }
}
54
/// HTTP Request representation.
///
/// Construct instances with [`Request::builder`]. Most request metadata is
/// exposed through public fields; the raw body is private and is filled in
/// by the builder.
pub struct Request {
    /// The HTTP method (GET, POST, PUT, etc.).
    pub method: Method,
    /// The request URI (path and query string).
    pub uri: Uri,
    /// The HTTP protocol version.
    pub version: Version,
    /// The request headers.
    pub headers: HeaderMap,
    /// Raw request body bytes, as supplied to the builder.
    body: Bytes,
    /// Path parameters extracted from the URL pattern.
    ///
    /// Stored in URL pattern declaration order (see [`PathParams`]).
    pub path_params: PathParams,
    /// Query string parameters parsed from the URI.
    ///
    /// Populated by `RequestBuilder::build` from the URI it was given. Being a
    /// `HashMap`, it holds at most one value per key.
    pub query_params: HashMap<String, String>,
    /// Indicates if this request came over HTTPS.
    ///
    /// Set via `RequestBuilder::secure`; defaults to `false`.
    pub is_secure: bool,
    /// Remote address of the connection peer (if available).
    ///
    /// Behind a reverse proxy this is the proxy's address. See
    /// `get_client_ip` for how forwarded client addresses are resolved.
    pub remote_addr: Option<SocketAddr>,
    /// Parsers for request body
    #[cfg(feature = "parsers")]
    parsers: Vec<Box<dyn Parser>>,
    /// Cached parsed data (lazy parsing); starts out as `None`.
    #[cfg(feature = "parsers")]
    parsed_data: Arc<Mutex<Option<ParsedData>>>,
    /// Whether the body has been consumed; starts out as `false`.
    body_consumed: Arc<AtomicBool>,
    /// Extensions for storing arbitrary typed data, for example the DI
    /// context and the [`TrustedProxies`] configuration.
    ///
    /// `clone_for_di` clones this field, and the copy shares the same
    /// underlying store.
    pub extensions: Extensions,
}
87
/// Builder for constructing `Request` instances.
///
/// Provides a fluent API for building HTTP requests with optional parameters.
/// Only the URI is required. Conversion errors from [`RequestBuilder::uri`]
/// and [`RequestBuilder::header`] do not panic. They are recorded and returned
/// from [`RequestBuilder::build`].
///
/// # Examples
///
/// ```
/// use reinhardt_http::Request;
/// use hyper::Method;
///
/// let request = Request::builder()
///     .method(Method::GET)
///     .uri("/api/users?page=1")
///     .build()
///     .unwrap();
///
/// assert_eq!(request.method, Method::GET);
/// assert_eq!(request.path(), "/api/users");
/// assert_eq!(request.query_params.get("page"), Some(&"1".to_string()));
/// ```
pub struct RequestBuilder {
    /// HTTP method; `GET` by default.
    method: Method,
    /// Target URI; `build` fails while this is still `None`.
    uri: Option<Uri>,
    /// Protocol version; HTTP/1.1 by default.
    version: Version,
    /// Headers set via `headers` (replaces all) or `header` (sets one name).
    headers: HeaderMap,
    /// Request body; empty by default.
    body: Bytes,
    /// Whether the request is treated as HTTPS; `false` by default.
    is_secure: bool,
    /// Peer socket address, if known.
    remote_addr: Option<SocketAddr>,
    /// Path parameters to attach to the built request.
    path_params: PathParams,
    /// Captured error from invalid URI
    uri_error: Option<String>,
    /// Captured error from invalid header value (only the most recent one is kept)
    header_error: Option<String>,
    /// Body parsers to attach to the built request.
    #[cfg(feature = "parsers")]
    parsers: Vec<Box<dyn Parser>>,
}
124
125impl Default for RequestBuilder {
126 fn default() -> Self {
127 Self {
128 method: Method::GET,
129 uri: None,
130 version: Version::HTTP_11,
131 headers: HeaderMap::new(),
132 body: Bytes::new(),
133 is_secure: false,
134 remote_addr: None,
135 path_params: PathParams::new(),
136 uri_error: None,
137 header_error: None,
138 #[cfg(feature = "parsers")]
139 parsers: Vec::new(),
140 }
141 }
142}
143
144impl RequestBuilder {
145 /// Set the HTTP method.
146 ///
147 /// # Examples
148 ///
149 /// ```
150 /// use reinhardt_http::Request;
151 /// use hyper::Method;
152 ///
153 /// let request = Request::builder()
154 /// .method(Method::POST)
155 /// .uri("/api/users")
156 /// .build()
157 /// .unwrap();
158 ///
159 /// assert_eq!(request.method, Method::POST);
160 /// ```
161 pub fn method(mut self, method: Method) -> Self {
162 self.method = method;
163 self
164 }
165
166 /// Set the request URI.
167 ///
168 /// Accepts either a `&str` or `Uri`. Query parameters will be automatically parsed.
169 ///
170 /// # Examples
171 ///
172 /// ```
173 /// use reinhardt_http::Request;
174 /// use hyper::Method;
175 ///
176 /// let request = Request::builder()
177 /// .method(Method::GET)
178 /// .uri("/api/users?page=1&limit=10")
179 /// .build()
180 /// .unwrap();
181 ///
182 /// assert_eq!(request.path(), "/api/users");
183 /// assert_eq!(request.query_params.get("page"), Some(&"1".to_string()));
184 /// assert_eq!(request.query_params.get("limit"), Some(&"10".to_string()));
185 /// ```
186 pub fn uri<T>(mut self, uri: T) -> Self
187 where
188 T: TryInto<Uri>,
189 T::Error: std::fmt::Display,
190 {
191 match uri.try_into() {
192 Ok(uri) => {
193 self.uri = Some(uri);
194 }
195 Err(e) => {
196 self.uri_error = Some(format!("Invalid URI: {}", e));
197 }
198 }
199 self
200 }
201
202 /// Set the HTTP version.
203 ///
204 /// Defaults to HTTP/1.1 if not specified.
205 ///
206 /// # Examples
207 ///
208 /// ```
209 /// use reinhardt_http::Request;
210 /// use hyper::{Method, Version};
211 ///
212 /// let request = Request::builder()
213 /// .method(Method::GET)
214 /// .uri("/api/users")
215 /// .version(Version::HTTP_2)
216 /// .build()
217 /// .unwrap();
218 ///
219 /// assert_eq!(request.version, Version::HTTP_2);
220 /// ```
221 pub fn version(mut self, version: Version) -> Self {
222 self.version = version;
223 self
224 }
225
226 /// Set the request headers.
227 ///
228 /// Replaces all existing headers.
229 ///
230 /// # Examples
231 ///
232 /// ```
233 /// use reinhardt_http::Request;
234 /// use hyper::{Method, HeaderMap, header};
235 ///
236 /// let mut headers = HeaderMap::new();
237 /// headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
238 ///
239 /// let request = Request::builder()
240 /// .method(Method::POST)
241 /// .uri("/api/users")
242 /// .headers(headers.clone())
243 /// .build()
244 /// .unwrap();
245 ///
246 /// assert_eq!(request.headers.get(header::CONTENT_TYPE).unwrap(), "application/json");
247 /// ```
248 pub fn headers(mut self, headers: HeaderMap) -> Self {
249 self.headers = headers;
250 self
251 }
252
253 /// Add a single header to the request.
254 ///
255 /// # Examples
256 ///
257 /// ```
258 /// use reinhardt_http::Request;
259 /// use hyper::{Method, header};
260 ///
261 /// let request = Request::builder()
262 /// .method(Method::POST)
263 /// .uri("/api/users")
264 /// .header(header::CONTENT_TYPE, "application/json")
265 /// .header(header::AUTHORIZATION, "Bearer token123")
266 /// .build()
267 /// .unwrap();
268 ///
269 /// assert_eq!(request.headers.get(header::CONTENT_TYPE).unwrap(), "application/json");
270 /// assert_eq!(request.headers.get(header::AUTHORIZATION).unwrap(), "Bearer token123");
271 /// ```
272 pub fn header<K, V>(mut self, key: K, value: V) -> Self
273 where
274 K: hyper::header::IntoHeaderName,
275 V: TryInto<hyper::header::HeaderValue>,
276 V::Error: std::fmt::Display,
277 {
278 match value.try_into() {
279 Ok(val) => {
280 self.headers.insert(key, val);
281 }
282 Err(e) => {
283 self.header_error = Some(format!("Invalid header value: {}", e));
284 }
285 }
286 self
287 }
288
289 /// Set the request body.
290 ///
291 /// # Examples
292 ///
293 /// ```
294 /// use reinhardt_http::Request;
295 /// use hyper::Method;
296 /// use bytes::Bytes;
297 ///
298 /// let body = Bytes::from(r#"{"name":"Alice"}"#);
299 /// let request = Request::builder()
300 /// .method(Method::POST)
301 /// .uri("/api/users")
302 /// .body(body.clone())
303 /// .build()
304 /// .unwrap();
305 ///
306 /// assert_eq!(request.body(), &body);
307 /// ```
308 pub fn body(mut self, body: Bytes) -> Self {
309 self.body = body;
310 self
311 }
312
313 /// Set whether the request is secure (HTTPS).
314 ///
315 /// Defaults to `false` if not specified.
316 ///
317 /// # Examples
318 ///
319 /// ```
320 /// use reinhardt_http::Request;
321 /// use hyper::Method;
322 ///
323 /// let request = Request::builder()
324 /// .method(Method::GET)
325 /// .uri("/")
326 /// .secure(true)
327 /// .build()
328 /// .unwrap();
329 ///
330 /// assert!(request.is_secure());
331 /// assert_eq!(request.scheme(), "https");
332 /// ```
333 pub fn secure(mut self, is_secure: bool) -> Self {
334 self.is_secure = is_secure;
335 self
336 }
337
338 /// Set the remote address of the client.
339 ///
340 /// # Examples
341 ///
342 /// ```
343 /// use reinhardt_http::Request;
344 /// use hyper::Method;
345 /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
346 ///
347 /// let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
348 /// let request = Request::builder()
349 /// .method(Method::GET)
350 /// .uri("/")
351 /// .remote_addr(addr)
352 /// .build()
353 /// .unwrap();
354 ///
355 /// assert_eq!(request.remote_addr, Some(addr));
356 /// ```
357 pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
358 self.remote_addr = Some(addr);
359 self
360 }
361
362 /// Add a parser to the request.
363 ///
364 /// Parsers are used to parse the request body into specific formats.
365 /// The parser will be boxed internally.
366 ///
367 /// # Examples
368 ///
369 /// ```ignore
370 /// use reinhardt_http::Request;
371 /// use hyper::Method;
372 ///
373 /// let request = Request::builder()
374 /// .method(Method::POST)
375 /// .uri("/api/users")
376 /// .parser(JsonParser::new())
377 /// .build()
378 /// .unwrap();
379 /// ```
380 #[cfg(feature = "parsers")]
381 pub fn parser<P: Parser + 'static>(mut self, parser: P) -> Self {
382 self.parsers.push(Box::new(parser));
383 self
384 }
385
386 /// Set path parameters (used for testing views without router).
387 ///
388 /// This is primarily useful in test environments where you need to simulate
389 /// path parameters that would normally be extracted by the router. Accepts
390 /// any value that can be converted into [`PathParams`], including a
391 /// `HashMap<String, String>` (note: converting from a `HashMap` does not
392 /// preserve ordering — pass a `Vec<(String, String)>` or [`PathParams`]
393 /// directly when ordering matters).
394 ///
395 /// # Examples
396 ///
397 /// ```
398 /// use reinhardt_http::Request;
399 /// use hyper::Method;
400 /// use std::collections::HashMap;
401 ///
402 /// let mut params = HashMap::new();
403 /// params.insert("id".to_string(), "42".to_string());
404 ///
405 /// let request = Request::builder()
406 /// .method(Method::GET)
407 /// .uri("/api/users/42")
408 /// .path_params(params)
409 /// .build()
410 /// .unwrap();
411 ///
412 /// assert_eq!(request.path_params.get("id"), Some(&"42".to_string()));
413 /// ```
414 pub fn path_params(mut self, params: impl Into<PathParams>) -> Self {
415 self.path_params = params.into();
416 self
417 }
418
419 /// Build the final `Request` instance.
420 ///
421 /// Returns an error if the URI is missing.
422 ///
423 /// # Examples
424 ///
425 /// ```
426 /// use reinhardt_http::Request;
427 /// use hyper::Method;
428 ///
429 /// let request = Request::builder()
430 /// .method(Method::GET)
431 /// .uri("/api/users")
432 /// .build()
433 /// .unwrap();
434 ///
435 /// assert_eq!(request.method, Method::GET);
436 /// assert_eq!(request.path(), "/api/users");
437 /// ```
438 pub fn build(self) -> Result<Request, String> {
439 // Report captured errors from builder methods
440 if let Some(err) = self.uri_error {
441 return Err(err);
442 }
443 if let Some(err) = self.header_error {
444 return Err(err);
445 }
446 let uri = self.uri.ok_or_else(|| "URI is required".to_string())?;
447 let query_params = Request::parse_query_params(&uri);
448
449 Ok(Request {
450 method: self.method,
451 uri,
452 version: self.version,
453 headers: self.headers,
454 body: self.body,
455 path_params: self.path_params,
456 query_params,
457 is_secure: self.is_secure,
458 remote_addr: self.remote_addr,
459 #[cfg(feature = "parsers")]
460 parsers: self.parsers,
461 #[cfg(feature = "parsers")]
462 parsed_data: Arc::new(Mutex::new(None)),
463 body_consumed: Arc::new(AtomicBool::new(false)),
464 extensions: Extensions::new(),
465 })
466 }
467}
468
469impl Request {
470 /// Create a new `RequestBuilder`.
471 ///
472 /// # Examples
473 ///
474 /// ```
475 /// use reinhardt_http::Request;
476 /// use hyper::Method;
477 ///
478 /// let request = Request::builder()
479 /// .method(Method::GET)
480 /// .uri("/api/users")
481 /// .build()
482 /// .unwrap();
483 ///
484 /// assert_eq!(request.method, Method::GET);
485 /// ```
486 pub fn builder() -> RequestBuilder {
487 RequestBuilder::default()
488 }
489
490 /// Set the DI context for this request (used by routers with dependency injection)
491 ///
492 /// This method stores the DI context in the request's extensions,
493 /// allowing handlers to access dependency injection services.
494 ///
495 /// The context will be wrapped in an Arc internally for efficient sharing.
496 /// The DI context type is generic to avoid circular dependencies.
497 ///
498 /// # Examples
499 ///
500 /// ```rust,no_run
501 /// use reinhardt_http::Request;
502 /// use hyper::Method;
503 ///
504 /// # struct DummyDiContext;
505 /// let mut request = Request::builder()
506 /// .method(Method::GET)
507 /// .uri("/")
508 /// .build()
509 /// .unwrap();
510 ///
511 /// let di_ctx = DummyDiContext;
512 /// request.set_di_context(di_ctx);
513 /// ```
514 pub fn set_di_context<T: Send + Sync + 'static>(&mut self, ctx: T) {
515 self.extensions.insert(Arc::new(ctx));
516 }
517
518 /// Get the DI context from this request
519 ///
520 /// Returns `None` if no DI context was set.
521 ///
522 /// The DI context type is generic to avoid circular dependencies.
523 /// Returns a reference to the context.
524 ///
525 /// # Examples
526 ///
527 /// ```rust,no_run
528 /// use reinhardt_http::Request;
529 /// use hyper::Method;
530 ///
531 /// # struct DummyDiContext;
532 /// let mut request = Request::builder()
533 /// .method(Method::GET)
534 /// .uri("/")
535 /// .build()
536 /// .unwrap();
537 ///
538 /// let di_ctx = DummyDiContext;
539 /// request.set_di_context(di_ctx);
540 ///
541 /// let ctx = request.get_di_context::<DummyDiContext>();
542 /// assert!(ctx.is_some());
543 /// ```
544 pub fn get_di_context<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
545 self.extensions.get::<Arc<T>>()
546 }
547
548 /// Extract Bearer token from Authorization header
549 ///
550 /// Extracts JWT or other bearer tokens from the Authorization header.
551 /// Returns `None` if the header is missing or not in "Bearer `<token>`" format.
552 ///
553 /// # Examples
554 ///
555 /// ```
556 /// use reinhardt_http::Request;
557 /// use hyper::{Method, Version, HeaderMap, header};
558 /// use bytes::Bytes;
559 ///
560 /// let mut headers = HeaderMap::new();
561 /// headers.insert(
562 /// header::AUTHORIZATION,
563 /// "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9".parse().unwrap()
564 /// );
565 ///
566 /// let request = Request::builder()
567 /// .method(Method::GET)
568 /// .uri("/")
569 /// .version(Version::HTTP_11)
570 /// .headers(headers)
571 /// .body(Bytes::new())
572 /// .build()
573 /// .unwrap();
574 ///
575 /// let token = request.extract_bearer_token();
576 /// assert_eq!(token, Some("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9".to_string()));
577 /// ```
578 ///
579 /// # Missing or invalid header
580 ///
581 /// ```
582 /// use reinhardt_http::Request;
583 /// use hyper::{Method, Version, HeaderMap};
584 /// use bytes::Bytes;
585 ///
586 /// let request = Request::builder()
587 /// .method(Method::GET)
588 /// .uri("/")
589 /// .version(Version::HTTP_11)
590 /// .headers(HeaderMap::new())
591 /// .body(Bytes::new())
592 /// .build()
593 /// .unwrap();
594 ///
595 /// let token = request.extract_bearer_token();
596 /// assert_eq!(token, None);
597 /// ```
598 pub fn extract_bearer_token(&self) -> Option<String> {
599 self.headers
600 .get(hyper::header::AUTHORIZATION)
601 .and_then(|value| value.to_str().ok())
602 .and_then(|auth_str| auth_str.strip_prefix("Bearer ").map(|s| s.to_string()))
603 }
604
605 /// Get a specific header value from the request
606 ///
607 /// Returns `None` if the header is missing or cannot be converted to a string.
608 ///
609 /// # Examples
610 ///
611 /// ```
612 /// use reinhardt_http::Request;
613 /// use hyper::{Method, Version, HeaderMap, header};
614 /// use bytes::Bytes;
615 ///
616 /// let mut headers = HeaderMap::new();
617 /// headers.insert(
618 /// header::USER_AGENT,
619 /// "Mozilla/5.0".parse().unwrap()
620 /// );
621 ///
622 /// let request = Request::builder()
623 /// .method(Method::GET)
624 /// .uri("/")
625 /// .version(Version::HTTP_11)
626 /// .headers(headers)
627 /// .body(Bytes::new())
628 /// .build()
629 /// .unwrap();
630 ///
631 /// let user_agent = request.get_header("user-agent");
632 /// assert_eq!(user_agent, Some("Mozilla/5.0".to_string()));
633 /// ```
634 ///
635 /// # Missing header
636 ///
637 /// ```
638 /// use reinhardt_http::Request;
639 /// use hyper::{Method, Version, HeaderMap};
640 /// use bytes::Bytes;
641 ///
642 /// let request = Request::builder()
643 /// .method(Method::GET)
644 /// .uri("/")
645 /// .version(Version::HTTP_11)
646 /// .headers(HeaderMap::new())
647 /// .body(Bytes::new())
648 /// .build()
649 /// .unwrap();
650 ///
651 /// let header = request.get_header("x-custom-header");
652 /// assert_eq!(header, None);
653 /// ```
654 pub fn get_header(&self, name: &str) -> Option<String> {
655 self.headers
656 .get(name)
657 .and_then(|value| value.to_str().ok())
658 .map(|s| s.to_string())
659 }
660
661 /// Extract client IP address from the request
662 ///
663 /// Only trusts proxy headers (X-Forwarded-For, X-Real-IP) when the request
664 /// originates from a configured trusted proxy. Without trusted proxies,
665 /// falls back to the actual connection address.
666 ///
667 /// # Examples
668 ///
669 /// ```
670 /// use reinhardt_http::{Request, TrustedProxies};
671 /// use hyper::{Method, Version, HeaderMap, header};
672 /// use bytes::Bytes;
673 /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
674 ///
675 /// let proxy_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
676 /// let mut headers = HeaderMap::new();
677 /// headers.insert(
678 /// header::HeaderName::from_static("x-forwarded-for"),
679 /// "203.0.113.1, 198.51.100.1".parse().unwrap()
680 /// );
681 ///
682 /// let request = Request::builder()
683 /// .method(Method::GET)
684 /// .uri("/")
685 /// .version(Version::HTTP_11)
686 /// .headers(headers)
687 /// .remote_addr(SocketAddr::new(proxy_ip, 8080))
688 /// .body(Bytes::new())
689 /// .build()
690 /// .unwrap();
691 ///
692 /// // Configure trusted proxies to honor X-Forwarded-For
693 /// request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));
694 ///
695 /// let ip = request.get_client_ip();
696 /// assert_eq!(ip, Some("203.0.113.1".parse().unwrap()));
697 /// ```
698 ///
699 /// # No trusted proxy, fallback to remote_addr
700 ///
701 /// ```
702 /// use reinhardt_http::Request;
703 /// use hyper::{Method, Version, HeaderMap};
704 /// use bytes::Bytes;
705 /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
706 ///
707 /// let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
708 /// let request = Request::builder()
709 /// .method(Method::GET)
710 /// .uri("/")
711 /// .version(Version::HTTP_11)
712 /// .headers(HeaderMap::new())
713 /// .remote_addr(addr)
714 /// .body(Bytes::new())
715 /// .build()
716 /// .unwrap();
717 ///
718 /// let ip = request.get_client_ip();
719 /// assert_eq!(ip, Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
720 /// ```
721 pub fn get_client_ip(&self) -> Option<std::net::IpAddr> {
722 // Only trust proxy headers if the request comes from a configured trusted proxy
723 if self.is_from_trusted_proxy() {
724 // Try X-Forwarded-For header first (common in proxy setups)
725 if let Some(forwarded) = self.get_header("x-forwarded-for") {
726 // X-Forwarded-For can contain multiple IPs, take the first one
727 if let Some(first_ip) = forwarded.split(',').next()
728 && let Ok(ip) = first_ip.trim().parse()
729 {
730 return Some(ip);
731 }
732 }
733
734 // Try X-Real-IP header
735 if let Some(real_ip) = self.get_header("x-real-ip")
736 && let Ok(ip) = real_ip.parse()
737 {
738 return Some(ip);
739 }
740 }
741
742 // Fallback to remote_addr (actual connection info)
743 self.remote_addr.map(|addr| addr.ip())
744 }
745
746 /// Check if the request originates from a trusted proxy.
747 ///
748 /// Returns `true` only if [`TrustedProxies`] are configured (via
749 /// [`set_trusted_proxies`](Self::set_trusted_proxies)) **and** the
750 /// remote address of the connection is contained in the trusted set.
751 ///
752 /// # Security
753 ///
754 /// This method gates whether proxy-forwarded headers (e.g.
755 /// `X-Forwarded-For`, `X-Forwarded-Proto`) should be honoured.
756 /// Trusting headers from a non-proxy source allows clients to spoof
757 /// their IP address or protocol, which can bypass IP-based access
758 /// controls and HTTPS enforcement.
759 ///
760 /// **Callers must ensure that [`TrustedProxies`] is configured only
761 /// with IP addresses of reverse proxies actually deployed in front
762 /// of the application.** Misconfiguration (e.g. trusting `0.0.0.0/0`)
763 /// re-introduces header-spoofing vulnerabilities.
764 ///
765 /// # Examples
766 ///
767 /// ```
768 /// use reinhardt_http::Request;
769 /// use reinhardt_http::TrustedProxies;
770 /// use bytes::Bytes;
771 /// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
772 /// use hyper::Method;
773 ///
774 /// let proxy_ip: IpAddr = Ipv4Addr::new(10, 0, 0, 1).into();
775 /// let request = Request::builder()
776 /// .method(Method::GET)
777 /// .uri("/")
778 /// .remote_addr(SocketAddr::new(proxy_ip, 8080))
779 /// .body(Bytes::new())
780 /// .build()
781 /// .unwrap();
782 /// request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));
783 ///
784 /// assert!(request.is_from_trusted_proxy());
785 /// ```
786 pub fn is_from_trusted_proxy(&self) -> bool {
787 if let Some(trusted) = self.extensions.get::<TrustedProxies>()
788 && let Some(addr) = self.remote_addr
789 {
790 return trusted.is_trusted(&addr.ip());
791 }
792 false
793 }
794
795 /// Set trusted proxy configuration for this request.
796 ///
797 /// This is typically called by the server/middleware layer to configure
798 /// which proxy IPs are trusted for header forwarding.
799 pub fn set_trusted_proxies(&self, proxies: TrustedProxies) {
800 self.extensions.insert(proxies);
801 }
802
803 /// Validate Content-Type header
804 ///
805 /// Checks if the Content-Type header matches the expected value.
806 /// Returns an error if the header is missing or doesn't match.
807 ///
808 /// # Examples
809 ///
810 /// ```
811 /// use reinhardt_http::Request;
812 /// use hyper::{Method, Version, HeaderMap, header};
813 /// use bytes::Bytes;
814 ///
815 /// let mut headers = HeaderMap::new();
816 /// headers.insert(
817 /// header::CONTENT_TYPE,
818 /// "application/json".parse().unwrap()
819 /// );
820 ///
821 /// let request = Request::builder()
822 /// .method(Method::POST)
823 /// .uri("/")
824 /// .version(Version::HTTP_11)
825 /// .headers(headers)
826 /// .body(Bytes::new())
827 /// .build()
828 /// .unwrap();
829 ///
830 /// assert!(request.validate_content_type("application/json").is_ok());
831 /// ```
832 ///
833 /// # Content-Type mismatch
834 ///
835 /// ```
836 /// use reinhardt_http::Request;
837 /// use hyper::{Method, Version, HeaderMap, header};
838 /// use bytes::Bytes;
839 ///
840 /// let mut headers = HeaderMap::new();
841 /// headers.insert(
842 /// header::CONTENT_TYPE,
843 /// "text/plain".parse().unwrap()
844 /// );
845 ///
846 /// let request = Request::builder()
847 /// .method(Method::POST)
848 /// .uri("/")
849 /// .version(Version::HTTP_11)
850 /// .headers(headers)
851 /// .body(Bytes::new())
852 /// .build()
853 /// .unwrap();
854 ///
855 /// let result = request.validate_content_type("application/json");
856 /// assert!(result.is_err());
857 /// ```
858 ///
859 /// # Missing Content-Type header
860 ///
861 /// ```
862 /// use reinhardt_http::Request;
863 /// use hyper::{Method, Version, HeaderMap};
864 /// use bytes::Bytes;
865 ///
866 /// let request = Request::builder()
867 /// .method(Method::POST)
868 /// .uri("/")
869 /// .version(Version::HTTP_11)
870 /// .headers(HeaderMap::new())
871 /// .body(Bytes::new())
872 /// .build()
873 /// .unwrap();
874 ///
875 /// let result = request.validate_content_type("application/json");
876 /// assert!(result.is_err());
877 /// ```
878 pub fn validate_content_type(&self, expected: &str) -> crate::Result<()> {
879 match self.get_header("content-type") {
880 Some(content_type) if content_type.starts_with(expected) => Ok(()),
881 Some(content_type) => Err(crate::Error::Http(format!(
882 "Invalid Content-Type: expected '{}', got '{}'",
883 expected, content_type
884 ))),
885 None => Err(crate::Error::Http(
886 "Missing Content-Type header".to_string(),
887 )),
888 }
889 }
890
891 /// Parse query parameters into typed struct
892 ///
893 /// Deserializes query string parameters into the specified type `T`.
894 /// Returns an error if deserialization fails.
895 ///
896 /// # Examples
897 ///
898 /// ```
899 /// use reinhardt_http::Request;
900 /// use hyper::{Method, Version, HeaderMap};
901 /// use bytes::Bytes;
902 /// use serde::Deserialize;
903 ///
904 /// #[derive(Deserialize, Debug, PartialEq)]
905 /// struct Pagination {
906 /// page: u32,
907 /// limit: u32,
908 /// }
909 ///
910 /// let request = Request::builder()
911 /// .method(Method::GET)
912 /// .uri("/api/users?page=2&limit=10")
913 /// .version(Version::HTTP_11)
914 /// .headers(HeaderMap::new())
915 /// .body(Bytes::new())
916 /// .build()
917 /// .unwrap();
918 ///
919 /// let params: Pagination = request.query_as().unwrap();
920 /// assert_eq!(params, Pagination { page: 2, limit: 10 });
921 /// ```
922 ///
923 /// # Type mismatch error
924 ///
925 /// ```
926 /// use reinhardt_http::Request;
927 /// use hyper::{Method, Version, HeaderMap};
928 /// use bytes::Bytes;
929 /// use serde::Deserialize;
930 ///
931 /// #[derive(Deserialize)]
932 /// struct Pagination {
933 /// page: u32,
934 /// limit: u32,
935 /// }
936 ///
937 /// let request = Request::builder()
938 /// .method(Method::GET)
939 /// .uri("/api/users?page=invalid")
940 /// .version(Version::HTTP_11)
941 /// .headers(HeaderMap::new())
942 /// .body(Bytes::new())
943 /// .build()
944 /// .unwrap();
945 ///
946 /// let result: Result<Pagination, _> = request.query_as();
947 /// assert!(result.is_err());
948 /// ```
949 pub fn query_as<T: serde::de::DeserializeOwned>(&self) -> crate::Result<T> {
950 // Convert HashMap<String, String> to Vec<(String, String)> for serde_urlencoded
951 let params: Vec<(String, String)> = self
952 .query_params
953 .iter()
954 .map(|(k, v)| (k.clone(), v.clone()))
955 .collect();
956
957 let encoded = serde_urlencoded::to_string(¶ms)
958 .map_err(|e| crate::Error::Http(format!("Failed to encode query parameters: {}", e)))?;
959 serde_urlencoded::from_str(&encoded)
960 .map_err(|e| crate::Error::Http(format!("Failed to parse query parameters: {}", e)))
961 }
962
963 /// Creates a lightweight copy of this request for dependency injection.
964 ///
965 /// The clone shares the same extensions store (via internal `Arc`),
966 /// so `AuthState` and other extensions set on the original request
967 /// are accessible in the clone. Body and parsers are not copied
968 /// as they are not needed for DI resolution.
969 pub fn clone_for_di(&self) -> Self {
970 Request {
971 method: self.method.clone(),
972 uri: self.uri.clone(),
973 version: self.version,
974 headers: self.headers.clone(),
975 body: Bytes::new(),
976 path_params: self.path_params.clone(),
977 query_params: self.query_params.clone(),
978 is_secure: self.is_secure,
979 remote_addr: self.remote_addr,
980 #[cfg(feature = "parsers")]
981 parsers: Vec::new(),
982 #[cfg(feature = "parsers")]
983 parsed_data: Arc::new(Mutex::new(None)),
984 body_consumed: Arc::new(AtomicBool::new(false)),
985 extensions: self.extensions.clone(),
986 }
987 }
988}
989
990#[cfg(test)]
991mod tests {
992 use super::*;
993 use bytes::Bytes;
994 use hyper::{HeaderMap, Method, Version, header};
995 use rstest::rstest;
996
997 #[rstest]
998 fn test_extract_bearer_token() {
999 let mut headers = HeaderMap::new();
1000 headers.insert(
1001 header::AUTHORIZATION,
1002 "Bearer test_token_123".parse().unwrap(),
1003 );
1004
1005 let request = Request::builder()
1006 .method(Method::GET)
1007 .uri("/")
1008 .version(Version::HTTP_11)
1009 .headers(headers)
1010 .body(Bytes::new())
1011 .build()
1012 .unwrap();
1013
1014 let token = request.extract_bearer_token();
1015 assert_eq!(token, Some("test_token_123".to_string()));
1016 }
1017
1018 #[rstest]
1019 fn test_extract_bearer_token_missing() {
1020 let request = Request::builder()
1021 .method(Method::GET)
1022 .uri("/")
1023 .version(Version::HTTP_11)
1024 .headers(HeaderMap::new())
1025 .body(Bytes::new())
1026 .build()
1027 .unwrap();
1028
1029 let token = request.extract_bearer_token();
1030 assert_eq!(token, None);
1031 }
1032
1033 #[rstest]
1034 fn test_get_header() {
1035 let mut headers = HeaderMap::new();
1036 headers.insert(header::USER_AGENT, "TestClient/1.0".parse().unwrap());
1037
1038 let request = Request::builder()
1039 .method(Method::GET)
1040 .uri("/")
1041 .version(Version::HTTP_11)
1042 .headers(headers)
1043 .body(Bytes::new())
1044 .build()
1045 .unwrap();
1046
1047 let user_agent = request.get_header("user-agent");
1048 assert_eq!(user_agent, Some("TestClient/1.0".to_string()));
1049 }
1050
1051 #[rstest]
1052 fn test_get_header_missing() {
1053 let request = Request::builder()
1054 .method(Method::GET)
1055 .uri("/")
1056 .version(Version::HTTP_11)
1057 .headers(HeaderMap::new())
1058 .body(Bytes::new())
1059 .build()
1060 .unwrap();
1061
1062 let header = request.get_header("x-custom-header");
1063 assert_eq!(header, None);
1064 }
1065
1066 #[rstest]
1067 fn test_get_client_ip_forwarded_for_with_trusted_proxy() {
1068 // Arrange
1069 let proxy_ip: std::net::IpAddr = "10.0.0.254".parse().unwrap();
1070 let mut headers = HeaderMap::new();
1071 headers.insert(
1072 header::HeaderName::from_static("x-forwarded-for"),
1073 "192.168.1.1, 10.0.0.1".parse().unwrap(),
1074 );
1075
1076 let request = Request::builder()
1077 .method(Method::GET)
1078 .uri("/")
1079 .version(Version::HTTP_11)
1080 .headers(headers)
1081 .body(Bytes::new())
1082 .remote_addr(std::net::SocketAddr::new(proxy_ip, 8080))
1083 .build()
1084 .unwrap();
1085
1086 // Configure trusted proxies
1087 request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));
1088
1089 // Act & Assert
1090 let ip = request.get_client_ip();
1091 assert_eq!(ip, Some("192.168.1.1".parse().unwrap()));
1092 }
1093
1094 #[rstest]
1095 fn test_get_client_ip_forwarded_for_without_trusted_proxy() {
1096 // Arrange - proxy headers present but no trusted proxy configured
1097 let mut headers = HeaderMap::new();
1098 headers.insert(
1099 header::HeaderName::from_static("x-forwarded-for"),
1100 "192.168.1.1, 10.0.0.1".parse().unwrap(),
1101 );
1102
1103 let remote_ip: std::net::IpAddr = "10.0.0.254".parse().unwrap();
1104 let request = Request::builder()
1105 .method(Method::GET)
1106 .uri("/")
1107 .version(Version::HTTP_11)
1108 .headers(headers)
1109 .body(Bytes::new())
1110 .remote_addr(std::net::SocketAddr::new(remote_ip, 8080))
1111 .build()
1112 .unwrap();
1113
1114 // Act - no trusted proxies, should use remote_addr
1115 let ip = request.get_client_ip();
1116 assert_eq!(ip, Some(remote_ip));
1117 }
1118
1119 #[rstest]
1120 fn test_get_client_ip_real_ip_with_trusted_proxy() {
1121 // Arrange
1122 let proxy_ip: std::net::IpAddr = "10.0.0.254".parse().unwrap();
1123 let mut headers = HeaderMap::new();
1124 headers.insert(
1125 header::HeaderName::from_static("x-real-ip"),
1126 "203.0.113.5".parse().unwrap(),
1127 );
1128
1129 let request = Request::builder()
1130 .method(Method::GET)
1131 .uri("/")
1132 .version(Version::HTTP_11)
1133 .headers(headers)
1134 .body(Bytes::new())
1135 .remote_addr(std::net::SocketAddr::new(proxy_ip, 8080))
1136 .build()
1137 .unwrap();
1138
1139 request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));
1140
1141 // Act & Assert
1142 let ip = request.get_client_ip();
1143 assert_eq!(ip, Some("203.0.113.5".parse().unwrap()));
1144 }
1145
1146 #[rstest]
1147 fn test_get_client_ip_none() {
1148 let request = Request::builder()
1149 .method(Method::GET)
1150 .uri("/")
1151 .version(Version::HTTP_11)
1152 .headers(HeaderMap::new())
1153 .body(Bytes::new())
1154 .build()
1155 .unwrap();
1156
1157 let ip = request.get_client_ip();
1158 assert_eq!(ip, None);
1159 }
1160
1161 #[rstest]
1162 fn test_validate_content_type_valid() {
1163 let mut headers = HeaderMap::new();
1164 headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
1165
1166 let request = Request::builder()
1167 .method(Method::POST)
1168 .uri("/")
1169 .version(Version::HTTP_11)
1170 .headers(headers)
1171 .body(Bytes::new())
1172 .build()
1173 .unwrap();
1174
1175 assert!(request.validate_content_type("application/json").is_ok());
1176 }
1177
1178 #[rstest]
1179 fn test_validate_content_type_invalid() {
1180 let mut headers = HeaderMap::new();
1181 headers.insert(header::CONTENT_TYPE, "text/plain".parse().unwrap());
1182
1183 let request = Request::builder()
1184 .method(Method::POST)
1185 .uri("/")
1186 .version(Version::HTTP_11)
1187 .headers(headers)
1188 .body(Bytes::new())
1189 .build()
1190 .unwrap();
1191
1192 assert!(request.validate_content_type("application/json").is_err());
1193 }
1194
1195 #[rstest]
1196 fn test_validate_content_type_missing() {
1197 let request = Request::builder()
1198 .method(Method::POST)
1199 .uri("/")
1200 .version(Version::HTTP_11)
1201 .headers(HeaderMap::new())
1202 .body(Bytes::new())
1203 .build()
1204 .unwrap();
1205
1206 assert!(request.validate_content_type("application/json").is_err());
1207 }
1208
1209 #[rstest]
1210 fn test_clone_for_di_shares_extensions() {
1211 // Arrange
1212 let request = Request::builder()
1213 .method(Method::POST)
1214 .uri("/api/users/42?page=1")
1215 .version(Version::HTTP_11)
1216 .header(header::CONTENT_TYPE, "application/json")
1217 .body(Bytes::from("request body"))
1218 .build()
1219 .unwrap();
1220
1221 request.extensions.insert(42u32);
1222
1223 // Act
1224 let cloned = request.clone_for_di();
1225
1226 // Assert - extensions are shared (same Arc backing store)
1227 assert_eq!(cloned.extensions.get::<u32>(), Some(42));
1228
1229 // Verify metadata is preserved
1230 assert_eq!(cloned.method, Method::POST);
1231 assert_eq!(cloned.uri.path(), "/api/users/42");
1232 assert_eq!(cloned.version, Version::HTTP_11);
1233 assert!(cloned.headers.contains_key(header::CONTENT_TYPE));
1234 assert_eq!(cloned.query_params.get("page"), Some(&"1".to_string()));
1235
1236 // Body should be empty (not needed for DI)
1237 assert!(cloned.body().is_empty());
1238 }
1239
1240 #[rstest]
1241 fn test_clone_for_di_shares_extensions_bidirectionally() {
1242 // Arrange
1243 let request = Request::builder()
1244 .method(Method::GET)
1245 .uri("/")
1246 .build()
1247 .unwrap();
1248
1249 let cloned = request.clone_for_di();
1250
1251 // Act - insert into cloned extensions
1252 cloned.extensions.insert("from_clone".to_string());
1253
1254 // Assert - original also sees it (shared backing store)
1255 assert_eq!(
1256 request.extensions.get::<String>(),
1257 Some("from_clone".to_string())
1258 );
1259 }
1260}