reinhardt_http/request.rs
1mod body;
2mod methods;
3mod params;
4
5use crate::extensions::Extensions;
6use bytes::Bytes;
7use hyper::{HeaderMap, Method, Uri, Version};
8#[cfg(feature = "parsers")]
9use reinhardt_core::parsers::parser::{ParsedData, Parser};
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::net::{IpAddr, SocketAddr};
13use std::sync::atomic::AtomicBool;
14use std::sync::{Arc, Mutex};
15
16/// Configuration for trusted proxy IPs.
17///
18/// Only proxy headers (X-Forwarded-For, X-Real-IP, X-Forwarded-Proto) from
19/// these IP addresses will be trusted. By default, no proxies are trusted
20/// and the actual connection information is used.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TrustedProxies {
    /// Set of trusted proxy IP addresses, stored in canonical form
    /// (an IPv4-mapped IPv6 address is stored as the plain IPv4 address).
    /// Only requests originating from these IPs will have their proxy headers honored.
    trusted_ips: HashSet<IpAddr>,
}

impl TrustedProxies {
    /// Create with no trusted proxies (default, most secure).
    pub fn none() -> Self {
        Self::default()
    }

    /// Create with a set of trusted proxy IPs.
    ///
    /// Each address is canonicalized before it is stored, so `10.0.0.1` and
    /// `::ffff:10.0.0.1` describe the same proxy.
    pub fn new(ips: impl IntoIterator<Item = IpAddr>) -> Self {
        Self {
            trusted_ips: ips.into_iter().map(|ip| ip.to_canonical()).collect(),
        }
    }

    /// Check if the given address is a trusted proxy.
    ///
    /// The address is canonicalized first: a dual-stack listener reports an
    /// IPv4 peer as an IPv4-mapped IPv6 address, which would otherwise never
    /// match a proxy that was configured by its IPv4 address.
    pub fn is_trusted(&self, addr: &IpAddr) -> bool {
        self.trusted_ips.contains(&addr.to_canonical())
    }

    /// Check if any proxies are configured.
    pub fn has_trusted_proxies(&self) -> bool {
        !self.trusted_ips.is_empty()
    }
}
53
54/// HTTP Request representation
55pub struct Request {
56 /// The HTTP method (GET, POST, PUT, etc.).
57 pub method: Method,
58 /// The request URI (path and query string).
59 pub uri: Uri,
60 /// The HTTP protocol version.
61 pub version: Version,
62 /// The request headers.
63 pub headers: HeaderMap,
64 body: Bytes,
65 /// Path parameters extracted from the URL pattern.
66 pub path_params: HashMap<String, String>,
67 /// Query string parameters parsed from the URI.
68 pub query_params: HashMap<String, String>,
69 /// Indicates if this request came over HTTPS
70 pub is_secure: bool,
71 /// Remote address of the client (if available)
72 pub remote_addr: Option<SocketAddr>,
73 /// Parsers for request body
74 #[cfg(feature = "parsers")]
75 parsers: Vec<Box<dyn Parser>>,
76 /// Cached parsed data (lazy parsing)
77 #[cfg(feature = "parsers")]
78 parsed_data: Arc<Mutex<Option<ParsedData>>>,
79 /// Whether the body has been consumed
80 body_consumed: Arc<AtomicBool>,
81 /// Extensions for storing arbitrary typed data
82 pub extensions: Extensions,
83}
84
85/// Builder for constructing `Request` instances.
86///
87/// Provides a fluent API for building HTTP requests with optional parameters.
88///
89/// # Examples
90///
91/// ```
92/// use reinhardt_http::Request;
93/// use hyper::Method;
94///
95/// let request = Request::builder()
96/// .method(Method::GET)
97/// .uri("/api/users?page=1")
98/// .build()
99/// .unwrap();
100///
101/// assert_eq!(request.method, Method::GET);
102/// assert_eq!(request.path(), "/api/users");
103/// assert_eq!(request.query_params.get("page"), Some(&"1".to_string()));
104/// ```
105pub struct RequestBuilder {
106 method: Method,
107 uri: Option<Uri>,
108 version: Version,
109 headers: HeaderMap,
110 body: Bytes,
111 is_secure: bool,
112 remote_addr: Option<SocketAddr>,
113 path_params: HashMap<String, String>,
114 /// Captured error from invalid URI
115 uri_error: Option<String>,
116 /// Captured error from invalid header value
117 header_error: Option<String>,
118 #[cfg(feature = "parsers")]
119 parsers: Vec<Box<dyn Parser>>,
120}
121
122impl Default for RequestBuilder {
123 fn default() -> Self {
124 Self {
125 method: Method::GET,
126 uri: None,
127 version: Version::HTTP_11,
128 headers: HeaderMap::new(),
129 body: Bytes::new(),
130 is_secure: false,
131 remote_addr: None,
132 path_params: HashMap::new(),
133 uri_error: None,
134 header_error: None,
135 #[cfg(feature = "parsers")]
136 parsers: Vec::new(),
137 }
138 }
139}
140
141impl RequestBuilder {
142 /// Set the HTTP method.
143 ///
144 /// # Examples
145 ///
146 /// ```
147 /// use reinhardt_http::Request;
148 /// use hyper::Method;
149 ///
150 /// let request = Request::builder()
151 /// .method(Method::POST)
152 /// .uri("/api/users")
153 /// .build()
154 /// .unwrap();
155 ///
156 /// assert_eq!(request.method, Method::POST);
157 /// ```
158 pub fn method(mut self, method: Method) -> Self {
159 self.method = method;
160 self
161 }
162
163 /// Set the request URI.
164 ///
165 /// Accepts either a `&str` or `Uri`. Query parameters will be automatically parsed.
166 ///
167 /// # Examples
168 ///
169 /// ```
170 /// use reinhardt_http::Request;
171 /// use hyper::Method;
172 ///
173 /// let request = Request::builder()
174 /// .method(Method::GET)
175 /// .uri("/api/users?page=1&limit=10")
176 /// .build()
177 /// .unwrap();
178 ///
179 /// assert_eq!(request.path(), "/api/users");
180 /// assert_eq!(request.query_params.get("page"), Some(&"1".to_string()));
181 /// assert_eq!(request.query_params.get("limit"), Some(&"10".to_string()));
182 /// ```
183 pub fn uri<T>(mut self, uri: T) -> Self
184 where
185 T: TryInto<Uri>,
186 T::Error: std::fmt::Display,
187 {
188 match uri.try_into() {
189 Ok(uri) => {
190 self.uri = Some(uri);
191 }
192 Err(e) => {
193 self.uri_error = Some(format!("Invalid URI: {}", e));
194 }
195 }
196 self
197 }
198
199 /// Set the HTTP version.
200 ///
201 /// Defaults to HTTP/1.1 if not specified.
202 ///
203 /// # Examples
204 ///
205 /// ```
206 /// use reinhardt_http::Request;
207 /// use hyper::{Method, Version};
208 ///
209 /// let request = Request::builder()
210 /// .method(Method::GET)
211 /// .uri("/api/users")
212 /// .version(Version::HTTP_2)
213 /// .build()
214 /// .unwrap();
215 ///
216 /// assert_eq!(request.version, Version::HTTP_2);
217 /// ```
218 pub fn version(mut self, version: Version) -> Self {
219 self.version = version;
220 self
221 }
222
223 /// Set the request headers.
224 ///
225 /// Replaces all existing headers.
226 ///
227 /// # Examples
228 ///
229 /// ```
230 /// use reinhardt_http::Request;
231 /// use hyper::{Method, HeaderMap, header};
232 ///
233 /// let mut headers = HeaderMap::new();
234 /// headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
235 ///
236 /// let request = Request::builder()
237 /// .method(Method::POST)
238 /// .uri("/api/users")
239 /// .headers(headers.clone())
240 /// .build()
241 /// .unwrap();
242 ///
243 /// assert_eq!(request.headers.get(header::CONTENT_TYPE).unwrap(), "application/json");
244 /// ```
245 pub fn headers(mut self, headers: HeaderMap) -> Self {
246 self.headers = headers;
247 self
248 }
249
    /// Set a single header on the request, replacing any existing value for that name.
251 ///
252 /// # Examples
253 ///
254 /// ```
255 /// use reinhardt_http::Request;
256 /// use hyper::{Method, header};
257 ///
258 /// let request = Request::builder()
259 /// .method(Method::POST)
260 /// .uri("/api/users")
261 /// .header(header::CONTENT_TYPE, "application/json")
262 /// .header(header::AUTHORIZATION, "Bearer token123")
263 /// .build()
264 /// .unwrap();
265 ///
266 /// assert_eq!(request.headers.get(header::CONTENT_TYPE).unwrap(), "application/json");
267 /// assert_eq!(request.headers.get(header::AUTHORIZATION).unwrap(), "Bearer token123");
268 /// ```
269 pub fn header<K, V>(mut self, key: K, value: V) -> Self
270 where
271 K: hyper::header::IntoHeaderName,
272 V: TryInto<hyper::header::HeaderValue>,
273 V::Error: std::fmt::Display,
274 {
275 match value.try_into() {
276 Ok(val) => {
277 self.headers.insert(key, val);
278 }
279 Err(e) => {
280 self.header_error = Some(format!("Invalid header value: {}", e));
281 }
282 }
283 self
284 }
285
286 /// Set the request body.
287 ///
288 /// # Examples
289 ///
290 /// ```
291 /// use reinhardt_http::Request;
292 /// use hyper::Method;
293 /// use bytes::Bytes;
294 ///
295 /// let body = Bytes::from(r#"{"name":"Alice"}"#);
296 /// let request = Request::builder()
297 /// .method(Method::POST)
298 /// .uri("/api/users")
299 /// .body(body.clone())
300 /// .build()
301 /// .unwrap();
302 ///
303 /// assert_eq!(request.body(), &body);
304 /// ```
305 pub fn body(mut self, body: Bytes) -> Self {
306 self.body = body;
307 self
308 }
309
310 /// Set whether the request is secure (HTTPS).
311 ///
312 /// Defaults to `false` if not specified.
313 ///
314 /// # Examples
315 ///
316 /// ```
317 /// use reinhardt_http::Request;
318 /// use hyper::Method;
319 ///
320 /// let request = Request::builder()
321 /// .method(Method::GET)
322 /// .uri("/")
323 /// .secure(true)
324 /// .build()
325 /// .unwrap();
326 ///
327 /// assert!(request.is_secure());
328 /// assert_eq!(request.scheme(), "https");
329 /// ```
330 pub fn secure(mut self, is_secure: bool) -> Self {
331 self.is_secure = is_secure;
332 self
333 }
334
335 /// Set the remote address of the client.
336 ///
337 /// # Examples
338 ///
339 /// ```
340 /// use reinhardt_http::Request;
341 /// use hyper::Method;
342 /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
343 ///
344 /// let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
345 /// let request = Request::builder()
346 /// .method(Method::GET)
347 /// .uri("/")
348 /// .remote_addr(addr)
349 /// .build()
350 /// .unwrap();
351 ///
352 /// assert_eq!(request.remote_addr, Some(addr));
353 /// ```
354 pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
355 self.remote_addr = Some(addr);
356 self
357 }
358
359 /// Add a parser to the request.
360 ///
361 /// Parsers are used to parse the request body into specific formats.
362 /// The parser will be boxed internally.
363 ///
364 /// # Examples
365 ///
366 /// ```ignore
367 /// use reinhardt_http::Request;
368 /// use hyper::Method;
369 ///
370 /// let request = Request::builder()
371 /// .method(Method::POST)
372 /// .uri("/api/users")
373 /// .parser(JsonParser::new())
374 /// .build()
375 /// .unwrap();
376 /// ```
377 #[cfg(feature = "parsers")]
378 pub fn parser<P: Parser + 'static>(mut self, parser: P) -> Self {
379 self.parsers.push(Box::new(parser));
380 self
381 }
382
383 /// Set path parameters (used for testing views without router).
384 ///
385 /// This is primarily useful in test environments where you need to simulate
386 /// path parameters that would normally be extracted by the router.
387 ///
388 /// # Examples
389 ///
390 /// ```
391 /// use reinhardt_http::Request;
392 /// use hyper::Method;
393 /// use std::collections::HashMap;
394 ///
395 /// let mut params = HashMap::new();
396 /// params.insert("id".to_string(), "42".to_string());
397 ///
398 /// let request = Request::builder()
399 /// .method(Method::GET)
400 /// .uri("/api/users/42")
401 /// .path_params(params)
402 /// .build()
403 /// .unwrap();
404 ///
405 /// assert_eq!(request.path_params.get("id"), Some(&"42".to_string()));
406 /// ```
407 pub fn path_params(mut self, params: HashMap<String, String>) -> Self {
408 self.path_params = params;
409 self
410 }
411
412 /// Build the final `Request` instance.
413 ///
    /// Returns an error if the URI is missing, or if an earlier `uri` or
    /// `header` call was given a value that failed to convert.
415 ///
416 /// # Examples
417 ///
418 /// ```
419 /// use reinhardt_http::Request;
420 /// use hyper::Method;
421 ///
422 /// let request = Request::builder()
423 /// .method(Method::GET)
424 /// .uri("/api/users")
425 /// .build()
426 /// .unwrap();
427 ///
428 /// assert_eq!(request.method, Method::GET);
429 /// assert_eq!(request.path(), "/api/users");
430 /// ```
431 pub fn build(self) -> Result<Request, String> {
432 // Report captured errors from builder methods
433 if let Some(err) = self.uri_error {
434 return Err(err);
435 }
436 if let Some(err) = self.header_error {
437 return Err(err);
438 }
439 let uri = self.uri.ok_or_else(|| "URI is required".to_string())?;
440 let query_params = Request::parse_query_params(&uri);
441
442 Ok(Request {
443 method: self.method,
444 uri,
445 version: self.version,
446 headers: self.headers,
447 body: self.body,
448 path_params: self.path_params,
449 query_params,
450 is_secure: self.is_secure,
451 remote_addr: self.remote_addr,
452 #[cfg(feature = "parsers")]
453 parsers: self.parsers,
454 #[cfg(feature = "parsers")]
455 parsed_data: Arc::new(Mutex::new(None)),
456 body_consumed: Arc::new(AtomicBool::new(false)),
457 extensions: Extensions::new(),
458 })
459 }
460}
461
462impl Request {
463 /// Create a new `RequestBuilder`.
464 ///
465 /// # Examples
466 ///
467 /// ```
468 /// use reinhardt_http::Request;
469 /// use hyper::Method;
470 ///
471 /// let request = Request::builder()
472 /// .method(Method::GET)
473 /// .uri("/api/users")
474 /// .build()
475 /// .unwrap();
476 ///
477 /// assert_eq!(request.method, Method::GET);
478 /// ```
479 pub fn builder() -> RequestBuilder {
480 RequestBuilder::default()
481 }
482
483 /// Set the DI context for this request (used by routers with dependency injection)
484 ///
485 /// This method stores the DI context in the request's extensions,
486 /// allowing handlers to access dependency injection services.
487 ///
488 /// The context will be wrapped in an Arc internally for efficient sharing.
489 /// The DI context type is generic to avoid circular dependencies.
490 ///
491 /// # Examples
492 ///
493 /// ```rust,no_run
494 /// use reinhardt_http::Request;
495 /// use hyper::Method;
496 ///
497 /// # struct DummyDiContext;
498 /// let mut request = Request::builder()
499 /// .method(Method::GET)
500 /// .uri("/")
501 /// .build()
502 /// .unwrap();
503 ///
504 /// let di_ctx = DummyDiContext;
505 /// request.set_di_context(di_ctx);
506 /// ```
507 pub fn set_di_context<T: Send + Sync + 'static>(&mut self, ctx: T) {
508 self.extensions.insert(Arc::new(ctx));
509 }
510
511 /// Get the DI context from this request
512 ///
513 /// Returns `None` if no DI context was set.
514 ///
515 /// The DI context type is generic to avoid circular dependencies.
    /// Returns a shared [`Arc`] handle to the context rather than a borrowed reference.
517 ///
518 /// # Examples
519 ///
520 /// ```rust,no_run
521 /// use reinhardt_http::Request;
522 /// use hyper::Method;
523 ///
524 /// # struct DummyDiContext;
525 /// let mut request = Request::builder()
526 /// .method(Method::GET)
527 /// .uri("/")
528 /// .build()
529 /// .unwrap();
530 ///
531 /// let di_ctx = DummyDiContext;
532 /// request.set_di_context(di_ctx);
533 ///
534 /// let ctx = request.get_di_context::<DummyDiContext>();
535 /// assert!(ctx.is_some());
536 /// ```
537 pub fn get_di_context<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
538 self.extensions.get::<Arc<T>>()
539 }
540
541 /// Extract Bearer token from Authorization header
542 ///
543 /// Extracts JWT or other bearer tokens from the Authorization header.
544 /// Returns `None` if the header is missing or not in "Bearer `<token>`" format.
545 ///
546 /// # Examples
547 ///
548 /// ```
549 /// use reinhardt_http::Request;
550 /// use hyper::{Method, Version, HeaderMap, header};
551 /// use bytes::Bytes;
552 ///
553 /// let mut headers = HeaderMap::new();
554 /// headers.insert(
555 /// header::AUTHORIZATION,
556 /// "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9".parse().unwrap()
557 /// );
558 ///
559 /// let request = Request::builder()
560 /// .method(Method::GET)
561 /// .uri("/")
562 /// .version(Version::HTTP_11)
563 /// .headers(headers)
564 /// .body(Bytes::new())
565 /// .build()
566 /// .unwrap();
567 ///
568 /// let token = request.extract_bearer_token();
569 /// assert_eq!(token, Some("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9".to_string()));
570 /// ```
571 ///
572 /// # Missing or invalid header
573 ///
574 /// ```
575 /// use reinhardt_http::Request;
576 /// use hyper::{Method, Version, HeaderMap};
577 /// use bytes::Bytes;
578 ///
579 /// let request = Request::builder()
580 /// .method(Method::GET)
581 /// .uri("/")
582 /// .version(Version::HTTP_11)
583 /// .headers(HeaderMap::new())
584 /// .body(Bytes::new())
585 /// .build()
586 /// .unwrap();
587 ///
588 /// let token = request.extract_bearer_token();
589 /// assert_eq!(token, None);
590 /// ```
591 pub fn extract_bearer_token(&self) -> Option<String> {
592 self.headers
593 .get(hyper::header::AUTHORIZATION)
594 .and_then(|value| value.to_str().ok())
595 .and_then(|auth_str| auth_str.strip_prefix("Bearer ").map(|s| s.to_string()))
596 }
597
598 /// Get a specific header value from the request
599 ///
600 /// Returns `None` if the header is missing or cannot be converted to a string.
601 ///
602 /// # Examples
603 ///
604 /// ```
605 /// use reinhardt_http::Request;
606 /// use hyper::{Method, Version, HeaderMap, header};
607 /// use bytes::Bytes;
608 ///
609 /// let mut headers = HeaderMap::new();
610 /// headers.insert(
611 /// header::USER_AGENT,
612 /// "Mozilla/5.0".parse().unwrap()
613 /// );
614 ///
615 /// let request = Request::builder()
616 /// .method(Method::GET)
617 /// .uri("/")
618 /// .version(Version::HTTP_11)
619 /// .headers(headers)
620 /// .body(Bytes::new())
621 /// .build()
622 /// .unwrap();
623 ///
624 /// let user_agent = request.get_header("user-agent");
625 /// assert_eq!(user_agent, Some("Mozilla/5.0".to_string()));
626 /// ```
627 ///
628 /// # Missing header
629 ///
630 /// ```
631 /// use reinhardt_http::Request;
632 /// use hyper::{Method, Version, HeaderMap};
633 /// use bytes::Bytes;
634 ///
635 /// let request = Request::builder()
636 /// .method(Method::GET)
637 /// .uri("/")
638 /// .version(Version::HTTP_11)
639 /// .headers(HeaderMap::new())
640 /// .body(Bytes::new())
641 /// .build()
642 /// .unwrap();
643 ///
644 /// let header = request.get_header("x-custom-header");
645 /// assert_eq!(header, None);
646 /// ```
647 pub fn get_header(&self, name: &str) -> Option<String> {
648 self.headers
649 .get(name)
650 .and_then(|value| value.to_str().ok())
651 .map(|s| s.to_string())
652 }
653
654 /// Extract client IP address from the request
655 ///
656 /// Only trusts proxy headers (X-Forwarded-For, X-Real-IP) when the request
657 /// originates from a configured trusted proxy. Without trusted proxies,
658 /// falls back to the actual connection address.
659 ///
660 /// # Examples
661 ///
662 /// ```
663 /// use reinhardt_http::{Request, TrustedProxies};
664 /// use hyper::{Method, Version, HeaderMap, header};
665 /// use bytes::Bytes;
666 /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
667 ///
668 /// let proxy_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1));
669 /// let mut headers = HeaderMap::new();
670 /// headers.insert(
671 /// header::HeaderName::from_static("x-forwarded-for"),
672 /// "203.0.113.1, 198.51.100.1".parse().unwrap()
673 /// );
674 ///
675 /// let request = Request::builder()
676 /// .method(Method::GET)
677 /// .uri("/")
678 /// .version(Version::HTTP_11)
679 /// .headers(headers)
680 /// .remote_addr(SocketAddr::new(proxy_ip, 8080))
681 /// .body(Bytes::new())
682 /// .build()
683 /// .unwrap();
684 ///
685 /// // Configure trusted proxies to honor X-Forwarded-For
686 /// request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));
687 ///
688 /// let ip = request.get_client_ip();
689 /// assert_eq!(ip, Some("203.0.113.1".parse().unwrap()));
690 /// ```
691 ///
692 /// # No trusted proxy, fallback to remote_addr
693 ///
694 /// ```
695 /// use reinhardt_http::Request;
696 /// use hyper::{Method, Version, HeaderMap};
697 /// use bytes::Bytes;
698 /// use std::net::{SocketAddr, IpAddr, Ipv4Addr};
699 ///
700 /// let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
701 /// let request = Request::builder()
702 /// .method(Method::GET)
703 /// .uri("/")
704 /// .version(Version::HTTP_11)
705 /// .headers(HeaderMap::new())
706 /// .remote_addr(addr)
707 /// .body(Bytes::new())
708 /// .build()
709 /// .unwrap();
710 ///
711 /// let ip = request.get_client_ip();
712 /// assert_eq!(ip, Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
713 /// ```
714 pub fn get_client_ip(&self) -> Option<std::net::IpAddr> {
715 // Only trust proxy headers if the request comes from a configured trusted proxy
716 if self.is_from_trusted_proxy() {
717 // Try X-Forwarded-For header first (common in proxy setups)
718 if let Some(forwarded) = self.get_header("x-forwarded-for") {
719 // X-Forwarded-For can contain multiple IPs, take the first one
720 if let Some(first_ip) = forwarded.split(',').next()
721 && let Ok(ip) = first_ip.trim().parse()
722 {
723 return Some(ip);
724 }
725 }
726
727 // Try X-Real-IP header
728 if let Some(real_ip) = self.get_header("x-real-ip")
729 && let Ok(ip) = real_ip.parse()
730 {
731 return Some(ip);
732 }
733 }
734
735 // Fallback to remote_addr (actual connection info)
736 self.remote_addr.map(|addr| addr.ip())
737 }
738
739 /// Check if the request originates from a trusted proxy.
740 ///
741 /// Returns `true` only if [`TrustedProxies`] are configured (via
742 /// [`set_trusted_proxies`](Self::set_trusted_proxies)) **and** the
743 /// remote address of the connection is contained in the trusted set.
744 ///
745 /// # Security
746 ///
747 /// This method gates whether proxy-forwarded headers (e.g.
748 /// `X-Forwarded-For`, `X-Forwarded-Proto`) should be honoured.
749 /// Trusting headers from a non-proxy source allows clients to spoof
750 /// their IP address or protocol, which can bypass IP-based access
751 /// controls and HTTPS enforcement.
752 ///
    /// **Callers must ensure that [`TrustedProxies`] is configured only
    /// with IP addresses of reverse proxies actually deployed in front
    /// of the application.** Misconfiguration (e.g. trusting an address
    /// that clients can connect from directly) re-introduces
    /// header-spoofing vulnerabilities.
757 ///
758 /// # Examples
759 ///
760 /// ```
761 /// use reinhardt_http::Request;
762 /// use reinhardt_http::TrustedProxies;
763 /// use bytes::Bytes;
764 /// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
765 /// use hyper::Method;
766 ///
767 /// let proxy_ip: IpAddr = Ipv4Addr::new(10, 0, 0, 1).into();
768 /// let request = Request::builder()
769 /// .method(Method::GET)
770 /// .uri("/")
771 /// .remote_addr(SocketAddr::new(proxy_ip, 8080))
772 /// .body(Bytes::new())
773 /// .build()
774 /// .unwrap();
775 /// request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));
776 ///
777 /// assert!(request.is_from_trusted_proxy());
778 /// ```
779 pub fn is_from_trusted_proxy(&self) -> bool {
780 if let Some(trusted) = self.extensions.get::<TrustedProxies>()
781 && let Some(addr) = self.remote_addr
782 {
783 return trusted.is_trusted(&addr.ip());
784 }
785 false
786 }
787
788 /// Set trusted proxy configuration for this request.
789 ///
790 /// This is typically called by the server/middleware layer to configure
791 /// which proxy IPs are trusted for header forwarding.
792 pub fn set_trusted_proxies(&self, proxies: TrustedProxies) {
793 self.extensions.insert(proxies);
794 }
795
796 /// Validate Content-Type header
797 ///
798 /// Checks if the Content-Type header matches the expected value.
799 /// Returns an error if the header is missing or doesn't match.
800 ///
801 /// # Examples
802 ///
803 /// ```
804 /// use reinhardt_http::Request;
805 /// use hyper::{Method, Version, HeaderMap, header};
806 /// use bytes::Bytes;
807 ///
808 /// let mut headers = HeaderMap::new();
809 /// headers.insert(
810 /// header::CONTENT_TYPE,
811 /// "application/json".parse().unwrap()
812 /// );
813 ///
814 /// let request = Request::builder()
815 /// .method(Method::POST)
816 /// .uri("/")
817 /// .version(Version::HTTP_11)
818 /// .headers(headers)
819 /// .body(Bytes::new())
820 /// .build()
821 /// .unwrap();
822 ///
823 /// assert!(request.validate_content_type("application/json").is_ok());
824 /// ```
825 ///
826 /// # Content-Type mismatch
827 ///
828 /// ```
829 /// use reinhardt_http::Request;
830 /// use hyper::{Method, Version, HeaderMap, header};
831 /// use bytes::Bytes;
832 ///
833 /// let mut headers = HeaderMap::new();
834 /// headers.insert(
835 /// header::CONTENT_TYPE,
836 /// "text/plain".parse().unwrap()
837 /// );
838 ///
839 /// let request = Request::builder()
840 /// .method(Method::POST)
841 /// .uri("/")
842 /// .version(Version::HTTP_11)
843 /// .headers(headers)
844 /// .body(Bytes::new())
845 /// .build()
846 /// .unwrap();
847 ///
848 /// let result = request.validate_content_type("application/json");
849 /// assert!(result.is_err());
850 /// ```
851 ///
852 /// # Missing Content-Type header
853 ///
854 /// ```
855 /// use reinhardt_http::Request;
856 /// use hyper::{Method, Version, HeaderMap};
857 /// use bytes::Bytes;
858 ///
859 /// let request = Request::builder()
860 /// .method(Method::POST)
861 /// .uri("/")
862 /// .version(Version::HTTP_11)
863 /// .headers(HeaderMap::new())
864 /// .body(Bytes::new())
865 /// .build()
866 /// .unwrap();
867 ///
868 /// let result = request.validate_content_type("application/json");
869 /// assert!(result.is_err());
870 /// ```
871 pub fn validate_content_type(&self, expected: &str) -> crate::Result<()> {
872 match self.get_header("content-type") {
873 Some(content_type) if content_type.starts_with(expected) => Ok(()),
874 Some(content_type) => Err(crate::Error::Http(format!(
875 "Invalid Content-Type: expected '{}', got '{}'",
876 expected, content_type
877 ))),
878 None => Err(crate::Error::Http(
879 "Missing Content-Type header".to_string(),
880 )),
881 }
882 }
883
884 /// Parse query parameters into typed struct
885 ///
886 /// Deserializes query string parameters into the specified type `T`.
887 /// Returns an error if deserialization fails.
888 ///
889 /// # Examples
890 ///
891 /// ```
892 /// use reinhardt_http::Request;
893 /// use hyper::{Method, Version, HeaderMap};
894 /// use bytes::Bytes;
895 /// use serde::Deserialize;
896 ///
897 /// #[derive(Deserialize, Debug, PartialEq)]
898 /// struct Pagination {
899 /// page: u32,
900 /// limit: u32,
901 /// }
902 ///
903 /// let request = Request::builder()
904 /// .method(Method::GET)
905 /// .uri("/api/users?page=2&limit=10")
906 /// .version(Version::HTTP_11)
907 /// .headers(HeaderMap::new())
908 /// .body(Bytes::new())
909 /// .build()
910 /// .unwrap();
911 ///
912 /// let params: Pagination = request.query_as().unwrap();
913 /// assert_eq!(params, Pagination { page: 2, limit: 10 });
914 /// ```
915 ///
916 /// # Type mismatch error
917 ///
918 /// ```
919 /// use reinhardt_http::Request;
920 /// use hyper::{Method, Version, HeaderMap};
921 /// use bytes::Bytes;
922 /// use serde::Deserialize;
923 ///
924 /// #[derive(Deserialize)]
925 /// struct Pagination {
926 /// page: u32,
927 /// limit: u32,
928 /// }
929 ///
930 /// let request = Request::builder()
931 /// .method(Method::GET)
932 /// .uri("/api/users?page=invalid")
933 /// .version(Version::HTTP_11)
934 /// .headers(HeaderMap::new())
935 /// .body(Bytes::new())
936 /// .build()
937 /// .unwrap();
938 ///
939 /// let result: Result<Pagination, _> = request.query_as();
940 /// assert!(result.is_err());
941 /// ```
942 pub fn query_as<T: serde::de::DeserializeOwned>(&self) -> crate::Result<T> {
943 // Convert HashMap<String, String> to Vec<(String, String)> for serde_urlencoded
944 let params: Vec<(String, String)> = self
945 .query_params
946 .iter()
947 .map(|(k, v)| (k.clone(), v.clone()))
948 .collect();
949
950 let encoded = serde_urlencoded::to_string(¶ms)
951 .map_err(|e| crate::Error::Http(format!("Failed to encode query parameters: {}", e)))?;
952 serde_urlencoded::from_str(&encoded)
953 .map_err(|e| crate::Error::Http(format!("Failed to parse query parameters: {}", e)))
954 }
955
956 /// Creates a lightweight copy of this request for dependency injection.
957 ///
958 /// The clone shares the same extensions store (via internal `Arc`),
959 /// so `AuthState` and other extensions set on the original request
960 /// are accessible in the clone. Body and parsers are not copied
961 /// as they are not needed for DI resolution.
962 pub fn clone_for_di(&self) -> Self {
963 Request {
964 method: self.method.clone(),
965 uri: self.uri.clone(),
966 version: self.version,
967 headers: self.headers.clone(),
968 body: Bytes::new(),
969 path_params: self.path_params.clone(),
970 query_params: self.query_params.clone(),
971 is_secure: self.is_secure,
972 remote_addr: self.remote_addr,
973 #[cfg(feature = "parsers")]
974 parsers: Vec::new(),
975 #[cfg(feature = "parsers")]
976 parsed_data: Arc::new(Mutex::new(None)),
977 body_consumed: Arc::new(AtomicBool::new(false)),
978 extensions: self.extensions.clone(),
979 }
980 }
981}
982
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version, header};
    use rstest::rstest;

    /// Builds an HTTP/1.1 request for `/` with an empty body, the given
    /// method and headers, and an optional peer address.
    fn request_with(
        method: Method,
        headers: HeaderMap,
        peer: Option<std::net::SocketAddr>,
    ) -> Request {
        let mut builder = Request::builder()
            .method(method)
            .uri("/")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new());
        if let Some(addr) = peer {
            builder = builder.remote_addr(addr);
        }
        builder.build().unwrap()
    }

    /// Builds a header map holding exactly one header.
    fn single_header(name: header::HeaderName, value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    /// Peer socket address on port 8080 for the given IP.
    fn peer_at(ip: std::net::IpAddr) -> std::net::SocketAddr {
        std::net::SocketAddr::new(ip, 8080)
    }

    #[rstest]
    fn test_extract_bearer_token() {
        let headers = single_header(header::AUTHORIZATION, "Bearer test_token_123");
        let request = request_with(Method::GET, headers, None);

        assert_eq!(
            request.extract_bearer_token(),
            Some("test_token_123".to_string())
        );
    }

    #[rstest]
    fn test_extract_bearer_token_missing() {
        let request = request_with(Method::GET, HeaderMap::new(), None);

        assert_eq!(request.extract_bearer_token(), None);
    }

    #[rstest]
    fn test_get_header() {
        let headers = single_header(header::USER_AGENT, "TestClient/1.0");
        let request = request_with(Method::GET, headers, None);

        assert_eq!(
            request.get_header("user-agent"),
            Some("TestClient/1.0".to_string())
        );
    }

    #[rstest]
    fn test_get_header_missing() {
        let request = request_with(Method::GET, HeaderMap::new(), None);

        assert_eq!(request.get_header("x-custom-header"), None);
    }

    #[rstest]
    fn test_get_client_ip_forwarded_for_with_trusted_proxy() {
        // Arrange: the connection comes from the proxy, which is trusted.
        let proxy_ip: std::net::IpAddr = "10.0.0.254".parse().unwrap();
        let headers = single_header(
            header::HeaderName::from_static("x-forwarded-for"),
            "192.168.1.1, 10.0.0.1",
        );
        let request = request_with(Method::GET, headers, Some(peer_at(proxy_ip)));
        request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));

        // Act
        let ip = request.get_client_ip();

        // Assert: the first forwarded entry wins.
        assert_eq!(ip, Some("192.168.1.1".parse().unwrap()));
    }

    #[rstest]
    fn test_get_client_ip_forwarded_for_without_trusted_proxy() {
        // Arrange: proxy header present, but no trusted proxy configured.
        let remote_ip: std::net::IpAddr = "10.0.0.254".parse().unwrap();
        let headers = single_header(
            header::HeaderName::from_static("x-forwarded-for"),
            "192.168.1.1, 10.0.0.1",
        );
        let request = request_with(Method::GET, headers, Some(peer_at(remote_ip)));

        // Act
        let ip = request.get_client_ip();

        // Assert: the header is ignored and the connection peer is reported.
        assert_eq!(ip, Some(remote_ip));
    }

    #[rstest]
    fn test_get_client_ip_real_ip_with_trusted_proxy() {
        // Arrange
        let proxy_ip: std::net::IpAddr = "10.0.0.254".parse().unwrap();
        let headers = single_header(
            header::HeaderName::from_static("x-real-ip"),
            "203.0.113.5",
        );
        let request = request_with(Method::GET, headers, Some(peer_at(proxy_ip)));
        request.set_trusted_proxies(TrustedProxies::new(vec![proxy_ip]));

        // Act
        let ip = request.get_client_ip();

        // Assert
        assert_eq!(ip, Some("203.0.113.5".parse().unwrap()));
    }

    #[rstest]
    fn test_get_client_ip_none() {
        // No headers and no peer address: nothing to report.
        let request = request_with(Method::GET, HeaderMap::new(), None);

        assert_eq!(request.get_client_ip(), None);
    }

    #[rstest]
    fn test_validate_content_type_valid() {
        let headers = single_header(header::CONTENT_TYPE, "application/json");
        let request = request_with(Method::POST, headers, None);

        assert!(request.validate_content_type("application/json").is_ok());
    }

    #[rstest]
    fn test_validate_content_type_invalid() {
        let headers = single_header(header::CONTENT_TYPE, "text/plain");
        let request = request_with(Method::POST, headers, None);

        assert!(request.validate_content_type("application/json").is_err());
    }

    #[rstest]
    fn test_validate_content_type_missing() {
        let request = request_with(Method::POST, HeaderMap::new(), None);

        assert!(request.validate_content_type("application/json").is_err());
    }

    #[rstest]
    fn test_clone_for_di_shares_extensions() {
        // Arrange
        let request = Request::builder()
            .method(Method::POST)
            .uri("/api/users/42?page=1")
            .version(Version::HTTP_11)
            .header(header::CONTENT_TYPE, "application/json")
            .body(Bytes::from("request body"))
            .build()
            .unwrap();
        request.extensions.insert(42u32);

        // Act
        let cloned = request.clone_for_di();

        // Assert: a value inserted on the original is visible through the clone.
        assert_eq!(cloned.extensions.get::<u32>(), Some(42));

        // Request metadata is carried over.
        assert_eq!(cloned.method, Method::POST);
        assert_eq!(cloned.uri.path(), "/api/users/42");
        assert_eq!(cloned.version, Version::HTTP_11);
        assert!(cloned.headers.contains_key(header::CONTENT_TYPE));
        assert_eq!(cloned.query_params.get("page"), Some(&"1".to_string()));

        // The body is deliberately not copied.
        assert!(cloned.body().is_empty());
    }

    #[rstest]
    fn test_clone_for_di_shares_extensions_bidirectionally() {
        // Arrange
        let request = Request::builder()
            .method(Method::GET)
            .uri("/")
            .build()
            .unwrap();
        let cloned = request.clone_for_di();

        // Act: insert through the clone.
        cloned.extensions.insert("from_clone".to_string());

        // Assert: the original observes the value too.
        assert_eq!(
            request.extensions.get::<String>(),
            Some("from_clone".to_string())
        );
    }
}
1253}