1use crate::{Error, HttpRequest};
38use serde::de::DeserializeOwned;
39use std::ops::Deref;
40use std::sync::Arc;
41
42pub trait FromRequest: Sized {
44 fn from_request(request: &HttpRequest) -> Result<Self, Error>;
46}
47
/// Shared application state pulled out of the request extensions.
///
/// The value lives behind an [`Arc`], so handing a `State` around only
/// bumps a reference count.
#[derive(Debug)]
pub struct State<T: Send + Sync + 'static>(pub Arc<T>);

impl<T> State<T>
where
    T: Send + Sync + 'static,
{
    /// Wraps an already shared value.
    #[inline]
    pub fn new(value: Arc<T>) -> Self {
        State(value)
    }

    /// Consumes the wrapper and hands back the shared pointer.
    #[inline]
    pub fn into_inner(self) -> Arc<T> {
        let State(shared) = self;
        shared
    }
}
107
108impl<T: Send + Sync + 'static> Clone for State<T> {
109 #[inline]
110 fn clone(&self) -> Self {
111 Self(Arc::clone(&self.0))
112 }
113}
114
115impl<T: Send + Sync + 'static> Deref for State<T> {
116 type Target = T;
117
118 #[inline]
119 fn deref(&self) -> &Self::Target {
120 &self.0
121 }
122}
123
124impl<T: Send + Sync + 'static> AsRef<T> for State<T> {
125 #[inline]
126 fn as_ref(&self) -> &T {
127 &self.0
128 }
129}
130
131impl<T: Send + Sync + 'static> FromRequest for State<T> {
132 #[inline]
139 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
140 request.extensions.get_arc::<T>().map(State).ok_or_else(|| {
141 Error::ProviderNotFound(format!(
142 "State<{}> not found in request extensions. \
143 Did you forget to register it with `app.with_state()`?",
144 std::any::type_name::<T>()
145 ))
146 })
147 }
148}
149
150pub trait FromRequestNamed: Sized {
152 fn from_request(request: &HttpRequest, name: &str) -> Result<Self, Error>;
154}
155
/// A request body deserialized from JSON into `T`.
#[derive(Debug, Clone)]
pub struct Body<T>(pub T);

impl<T> Body<T> {
    /// Wraps an already deserialized value.
    pub fn new(value: T) -> Self {
        Body(value)
    }

    /// Unwraps and returns the deserialized value.
    pub fn into_inner(self) -> T {
        let Body(inner) = self;
        inner
    }
}
186
187impl<T> Deref for Body<T> {
188 type Target = T;
189
190 fn deref(&self) -> &Self::Target {
191 &self.0
192 }
193}
194
195impl<T: DeserializeOwned> FromRequest for Body<T> {
196 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
197 let value: T = request.json()?;
198 Ok(Body(value))
199 }
200}
201
/// Query-string parameters deserialized into `T`.
#[derive(Debug, Clone)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Wraps an already deserialized value.
    pub fn new(value: T) -> Self {
        Query(value)
    }

    /// Unwraps and returns the deserialized value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }
}
233
234impl<T> Deref for Query<T> {
235 type Target = T;
236
237 fn deref(&self) -> &Self::Target {
238 &self.0
239 }
240}
241
242impl<T: DeserializeOwned> FromRequest for Query<T> {
243 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
244 let query_string: String = request
246 .query_params
247 .iter()
248 .map(|(k, v)| format!("{}={}", k, v))
249 .collect::<Vec<_>>()
250 .join("&");
251
252 let value: T = serde_urlencoded::from_str(&query_string)
253 .map_err(|e| Error::Validation(format!("Invalid query parameters: {}", e)))?;
254
255 Ok(Query(value))
256 }
257}
258
/// A single named path parameter parsed into `T` via [`std::str::FromStr`].
#[derive(Debug, Clone)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Wraps an already parsed value.
    pub fn new(value: T) -> Self {
        Path(value)
    }

    /// Unwraps and returns the parsed value.
    pub fn into_inner(self) -> T {
        let Path(inner) = self;
        inner
    }
}
284
285impl<T> Deref for Path<T> {
286 type Target = T;
287
288 fn deref(&self) -> &Self::Target {
289 &self.0
290 }
291}
292
293impl<T: std::str::FromStr> FromRequestNamed for Path<T>
294where
295 T::Err: std::fmt::Display,
296{
297 fn from_request(request: &HttpRequest, name: &str) -> Result<Self, Error> {
298 let value_str = request
299 .param(name)
300 .ok_or_else(|| Error::Validation(format!("Missing path parameter: {}", name)))?;
301
302 let value: T = value_str.parse().map_err(|e: T::Err| {
303 Error::Validation(format!("Invalid path parameter '{}': {}", name, e))
304 })?;
305
306 Ok(Path(value))
307 }
308}
309
/// All path parameters deserialized together into `T`.
#[derive(Debug, Clone)]
pub struct PathParams<T>(pub T);

impl<T> PathParams<T> {
    /// Wraps an already deserialized value.
    pub fn new(value: T) -> Self {
        PathParams(value)
    }

    /// Unwraps and returns the deserialized value.
    pub fn into_inner(self) -> T {
        let PathParams(inner) = self;
        inner
    }
}
340
341impl<T> Deref for PathParams<T> {
342 type Target = T;
343
344 fn deref(&self) -> &Self::Target {
345 &self.0
346 }
347}
348
349impl<T: DeserializeOwned> FromRequest for PathParams<T> {
350 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
351 let params_string: String = request
353 .path_params
354 .iter()
355 .map(|(k, v)| format!("{}={}", k, v))
356 .collect::<Vec<_>>()
357 .join("&");
358
359 let value: T = serde_urlencoded::from_str(¶ms_string)
360 .map_err(|e| Error::Validation(format!("Invalid path parameters: {}", e)))?;
361
362 Ok(PathParams(value))
363 }
364}
365
366#[derive(Debug, Clone)]
380pub struct Header {
381 name: String,
382 value: String,
383}
384
385impl Header {
386 pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
388 Self {
389 name: name.into(),
390 value: value.into(),
391 }
392 }
393
394 pub fn name(&self) -> &str {
396 &self.name
397 }
398
399 pub fn value(&self) -> &str {
401 &self.value
402 }
403
404 pub fn into_value(self) -> String {
406 self.value
407 }
408
409 pub fn optional(request: &HttpRequest, name: &str) -> Option<Self> {
411 request
412 .headers
413 .get(name)
414 .or_else(|| request.headers.get(&name.to_lowercase()))
415 .map(|v| Header::new(name, v.clone()))
416 }
417}
418
419impl FromRequestNamed for Header {
420 fn from_request(request: &HttpRequest, name: &str) -> Result<Self, Error> {
421 let value = request
422 .headers
423 .get(name)
424 .or_else(|| request.headers.get(&name.to_lowercase()))
425 .ok_or_else(|| Error::Validation(format!("Missing header: {}", name)))?;
426
427 Ok(Header::new(name, value.clone()))
428 }
429}
430
431impl Deref for Header {
432 type Target = str;
433
434 fn deref(&self) -> &Self::Target {
435 &self.value
436 }
437}
438
/// A snapshot of every request header, keyed by header name.
#[derive(Debug, Clone)]
pub struct Headers(pub std::collections::HashMap<String, String>);

impl Headers {
    /// Looks up the header `name`, ignoring ASCII case.
    ///
    /// An exact match is tried first, then the lowercase spelling (both O(1),
    /// preserving the previous preference order). Only then comes a linear
    /// case-insensitive scan, because HTTP field names are case-insensitive
    /// and the map may store them in any casing (e.g. `Content-Type`).
    pub fn get(&self, name: &str) -> Option<&String> {
        let map = &self.0;
        map.get(name)
            .or_else(|| map.get(&name.to_lowercase()))
            .or_else(|| {
                map.iter()
                    .find(|(key, _)| key.eq_ignore_ascii_case(name))
                    .map(|(_, value)| value)
            })
    }

    /// Whether a header called `name` is present (case-insensitively).
    pub fn contains(&self, name: &str) -> bool {
        self.get(name).is_some()
    }

    /// Iterates over all `(name, value)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
        self.0.iter()
    }
}
463
464impl FromRequest for Headers {
465 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
466 Ok(Headers(request.headers.clone()))
467 }
468}
469
470impl Deref for Headers {
471 type Target = std::collections::HashMap<String, String>;
472
473 fn deref(&self) -> &Self::Target {
474 &self.0
475 }
476}
477
/// The unparsed request body as raw bytes.
#[derive(Debug, Clone)]
pub struct RawBody(pub Vec<u8>);

impl RawBody {
    /// Wraps a byte buffer.
    pub fn new(data: Vec<u8>) -> Self {
        RawBody(data)
    }

    /// Number of bytes in the body.
    pub fn len(&self) -> usize {
        self.0.as_slice().len()
    }

    /// `true` when the body has no bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Decodes the bytes as UTF-8, replacing invalid sequences with U+FFFD.
    pub fn to_string_lossy(&self) -> String {
        String::from_utf8_lossy(self.0.as_slice()).into_owned()
    }

    /// Decodes the bytes as UTF-8.
    ///
    /// # Errors
    ///
    /// Returns [`std::string::FromUtf8Error`] when the bytes are not valid UTF-8.
    pub fn to_string(&self) -> Result<String, std::string::FromUtf8Error> {
        let copy = self.0.to_vec();
        String::from_utf8(copy)
    }

    /// Consumes the wrapper and returns the byte buffer.
    pub fn into_inner(self) -> Vec<u8> {
        let RawBody(bytes) = self;
        bytes
    }
}
522
523impl FromRequest for RawBody {
524 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
525 Ok(RawBody(request.body.clone()))
526 }
527}
528
529impl Deref for RawBody {
530 type Target = [u8];
531
532 fn deref(&self) -> &Self::Target {
533 &self.0
534 }
535}
536
/// A URL-encoded form body deserialized into `T`.
#[derive(Debug, Clone)]
pub struct Form<T>(pub T);

impl<T> Form<T> {
    /// Wraps an already deserialized value.
    pub fn new(value: T) -> Self {
        Form(value)
    }

    /// Unwraps and returns the deserialized value.
    pub fn into_inner(self) -> T {
        let Form(inner) = self;
        inner
    }
}
566
567impl<T> Deref for Form<T> {
568 type Target = T;
569
570 fn deref(&self) -> &Self::Target {
571 &self.0
572 }
573}
574
575impl<T: DeserializeOwned> FromRequest for Form<T> {
576 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
577 let value: T = request.form()?;
578 Ok(Form(value))
579 }
580}
581
/// The raw `Content-Type` header value (empty if the header was absent).
#[derive(Debug, Clone)]
pub struct ContentType(pub String);

impl ContentType {
    /// `true` if the value mentions `application/json` (any ASCII case).
    pub fn is_json(&self) -> bool {
        self.mentions("application/json")
    }

    /// `true` if the value mentions `application/x-www-form-urlencoded`
    /// (any ASCII case).
    pub fn is_form(&self) -> bool {
        self.mentions("application/x-www-form-urlencoded")
    }

    /// `true` if the value mentions `multipart/form-data` (any ASCII case).
    pub fn is_multipart(&self) -> bool {
        self.mentions("multipart/form-data")
    }

    /// Consumes the wrapper and returns the raw header value.
    pub fn into_inner(self) -> String {
        self.0
    }

    /// Substring check that ignores ASCII case. Media type and subtype tokens
    /// are case-insensitive, so `Application/JSON` must count as JSON.
    /// `needle` must already be lowercase.
    fn mentions(&self, needle: &str) -> bool {
        self.0.to_ascii_lowercase().contains(needle)
    }
}
609
610impl FromRequest for ContentType {
611 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
612 let value = request
613 .headers
614 .get("Content-Type")
615 .or_else(|| request.headers.get("content-type"))
616 .cloned()
617 .unwrap_or_default();
618
619 Ok(ContentType(value))
620 }
621}
622
623impl Deref for ContentType {
624 type Target = str;
625
626 fn deref(&self) -> &Self::Target {
627 &self.0
628 }
629}
630
/// The request's HTTP method token, e.g. `"GET"`.
///
/// The `is_*` helpers compare exactly (case-sensitively), as method tokens
/// are case-sensitive.
#[derive(Debug, Clone)]
pub struct Method(pub String);

impl Method {
    /// Exact comparison against a method token.
    fn matches_verb(&self, verb: &str) -> bool {
        matches!(self.0.as_str(), v if v == verb)
    }

    /// `true` for `GET`.
    pub fn is_get(&self) -> bool {
        self.matches_verb("GET")
    }

    /// `true` for `POST`.
    pub fn is_post(&self) -> bool {
        self.matches_verb("POST")
    }

    /// `true` for `PUT`.
    pub fn is_put(&self) -> bool {
        self.matches_verb("PUT")
    }

    /// `true` for `DELETE`.
    pub fn is_delete(&self) -> bool {
        self.matches_verb("DELETE")
    }

    /// `true` for `PATCH`.
    pub fn is_patch(&self) -> bool {
        self.matches_verb("PATCH")
    }
}
663
664impl FromRequest for Method {
665 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
666 Ok(Method(request.method.clone()))
667 }
668}
669
670impl Deref for Method {
671 type Target = str;
672
673 fn deref(&self) -> &Self::Target {
674 &self.0
675 }
676}
677
678impl FromRequest for HttpRequest {
681 fn from_request(request: &HttpRequest) -> Result<Self, Error> {
682 Ok(request.clone())
683 }
684}
685
/// Deserializes the JSON body of `$request` into `$type`.
///
/// Expands to a `Result<$type, Error>` (the `Body` wrapper is stripped).
#[macro_export]
macro_rules! body {
    ($request:expr, $type:ty) => {
        <$crate::extractors::Body<$type> as $crate::extractors::FromRequest>::from_request(
            &$request,
        )
        .map($crate::extractors::Body::into_inner)
    };
}
706
/// Deserializes the query parameters of `$request` into `$type`.
///
/// Expands to a `Result<$type, Error>` (the `Query` wrapper is stripped).
#[macro_export]
macro_rules! query {
    ($request:expr, $type:ty) => {
        <$crate::extractors::Query<$type> as $crate::extractors::FromRequest>::from_request(
            &$request,
        )
        .map($crate::extractors::Query::into_inner)
    };
}
723
/// Parses the path parameter `$name` of `$request` as `$type`.
///
/// Expands to a `Result<$type, Error>` (the `Path` wrapper is stripped).
#[macro_export]
macro_rules! path {
    ($request:expr, $name:expr, $type:ty) => {
        <$crate::extractors::Path<$type> as $crate::extractors::FromRequestNamed>::from_request(
            &$request, $name,
        )
        .map($crate::extractors::Path::into_inner)
    };
}
740
/// Reads the header `$name` from `$request`.
///
/// Expands to a `Result<String, Error>` holding just the header value.
#[macro_export]
macro_rules! header {
    ($request:expr, $name:expr) => {
        <$crate::extractors::Header as $crate::extractors::FromRequestNamed>::from_request(
            &$request, $name,
        )
        .map($crate::extractors::Header::into_value)
    };
}
757
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Fixture: `GET /users/123` with path param `id=123`, query params
    /// `page=1` / `limit=10`, and `Authorization` / `Content-Type` headers.
    fn create_request() -> HttpRequest {
        let mut req = HttpRequest::new(String::from("GET"), String::from("/users/123"));

        req.path_params
            .insert(String::from("id"), String::from("123"));

        let query = [("page", "1"), ("limit", "10")];
        for &(key, value) in query.iter() {
            req.query_params.insert(key.to_string(), value.to_string());
        }

        let headers = [
            ("Authorization", "Bearer token123"),
            ("Content-Type", "application/json"),
        ];
        for &(key, value) in headers.iter() {
            req.headers.insert(key.to_string(), value.to_string());
        }

        req
    }

    #[test]
    fn test_path_extraction() {
        let request = create_request();
        let Path(id) = Path::<u32>::from_request(&request, "id").unwrap();
        assert_eq!(id, 123);
    }

    #[test]
    fn test_path_missing() {
        let request = create_request();
        let outcome = Path::<u32>::from_request(&request, "missing");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_header_extraction() {
        let request = create_request();
        let auth = Header::from_request(&request, "Authorization").unwrap();
        assert_eq!(auth.value(), "Bearer token123");
    }

    #[test]
    fn test_header_optional() {
        let request = create_request();
        assert!(Header::optional(&request, "Authorization").is_some());
        assert!(Header::optional(&request, "X-Missing").is_none());
    }

    #[test]
    fn test_headers_extraction() {
        let request = create_request();
        let headers = Headers::from_request(&request).unwrap();

        for name in ["Authorization", "Content-Type"].iter() {
            assert!(headers.contains(name));
        }
        assert!(!headers.contains("X-Missing"));
    }

    #[test]
    fn test_query_extraction() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Pagination {
            page: u32,
            limit: u32,
        }

        let request = create_request();
        let Query(pagination) = Query::<Pagination>::from_request(&request).unwrap();
        assert_eq!(pagination, Pagination { page: 1, limit: 10 });
    }

    #[test]
    fn test_body_extraction() {
        #[derive(Debug, Deserialize)]
        struct CreateUser {
            name: String,
            email: String,
        }

        let payload = serde_json::json!({
            "name": "Test",
            "email": "test@example.com"
        });
        let mut request = create_request();
        request.body = serde_json::to_vec(&payload).unwrap();

        let Body(user) = Body::<CreateUser>::from_request(&request).unwrap();
        assert_eq!(user.name, "Test");
        assert_eq!(user.email, "test@example.com");
    }

    #[test]
    fn test_raw_body() {
        let mut request = create_request();
        request.body = b"raw content".to_vec();

        let raw = RawBody::from_request(&request).unwrap();
        assert_eq!(raw.len(), 11);
        assert_eq!(raw.to_string_lossy(), "raw content");
    }

    #[test]
    fn test_content_type() {
        let request = create_request();
        let ct = ContentType::from_request(&request).unwrap();

        assert!(ct.is_json());
        assert!(!ct.is_form());
        assert!(!ct.is_multipart());
    }

    #[test]
    fn test_method() {
        let request = create_request();
        let method = Method::from_request(&request).unwrap();

        assert!(method.is_get());
        assert!(!method.is_post());
    }

    #[test]
    fn test_request_extraction() {
        let request = create_request();
        let extracted = HttpRequest::from_request(&request).unwrap();

        assert_eq!(request.method, extracted.method);
        assert_eq!(request.path, extracted.path);
    }
}