volo_http/server/
extract.rs

1//! Traits and types for extracting data from [`ServerContext`] and [`Request`]
2//!
3//! See [`FromContext`] and [`FromRequest`] for more details.
4
5use std::{convert::Infallible, fmt, marker::PhantomData};
6
7use bytes::Bytes;
8use faststr::FastStr;
9use futures_util::Future;
10use http::{
11    header::{self, HeaderMap, HeaderName},
12    method::Method,
13    request::Parts,
14    status::StatusCode,
15    uri::{Scheme, Uri},
16};
17use http_body::Body;
18use http_body_util::BodyExt;
19use volo::{context::Context, net::Address};
20
21use super::IntoResponse;
22use crate::{
23    context::ServerContext,
24    error::server::{ExtractBodyError, body_collection_error},
25    request::{Request, RequestPartsExt},
26    server::utils::client_ip::ClientIp,
27    utils::macros::impl_deref_and_deref_mut,
28};
29
// Marker types used only at the type level to tell apart the two families of
// `FromRequest` implementations in this file. The module is private, so code outside
// this file cannot name the markers.
mod private {
    /// Marker used as the `M` parameter of `FromRequest` by the blanket implementation
    /// covering every type that implements `FromContext`.
    ///
    /// The enum has no variants, so no value of it can ever be constructed.
    #[derive(Debug, Clone, Copy)]
    pub enum ViaContext {}

    /// Marker used as the default `M` parameter of `FromRequest`.
    ///
    /// The enum has no variants, so no value of it can ever be constructed.
    #[derive(Debug, Clone, Copy)]
    pub enum ViaRequest {}
}
37
/// Extract a type from context ([`ServerContext`] and [`Parts`])
///
/// This trait is used for handlers, which can extract something from [`ServerContext`] and
/// [`Request`].
///
/// [`FromContext`] only borrows [`ServerContext`] and [`Parts`]. If your extractor needs to
/// consume [`Parts`] or the whole [`Request`], please use [`FromRequest`] instead.
///
/// Every type that implements [`FromContext`] (and is `Sync`) also implements
/// [`FromRequest`] through a blanket implementation in this module.
pub trait FromContext: Sized {
    /// If the extractor fails, it will return this `Rejection` type.
    ///
    /// The `Rejection` should implement [`IntoResponse`]. If extractor fails in handler, the
    /// rejection will be converted into a [`Response`](crate::response::Response) and
    /// returned.
    type Rejection: IntoResponse;

    /// Extract the type from context.
    ///
    /// Both `cx` and `parts` are only borrowed (mutably), so they stay available to other
    /// extractors afterwards. The returned future must be [`Send`].
    fn from_context(
        cx: &mut ServerContext,
        parts: &mut Parts,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
}
59
/// Extract a type from [`Request`] with its [`ServerContext`]
///
/// This trait is used for handlers, which can extract something from [`ServerContext`] and
/// [`Request`].
///
/// [`FromRequest`] will consume [`Request`], so it can only be used once in a handler. If
/// your extractor does not need to consume [`Request`], please use [`FromContext`] instead.
///
/// # Type parameters
///
/// - `B`: the type of the request body, [`Body`](crate::body::Body) by default.
/// - `M`: a private marker that distinguishes direct implementations (the default) from the
///   blanket implementation for [`FromContext`] types. Implementors normally leave it at its
///   default.
pub trait FromRequest<B = crate::body::Body, M = private::ViaRequest>: Sized {
    /// If the extractor fails, it will return this `Rejection` type.
    ///
    /// The `Rejection` should implement [`IntoResponse`]. If extractor fails in handler, the
    /// rejection will be converted into a [`Response`](crate::response::Response) and
    /// returned.
    type Rejection: IntoResponse;

    /// Extract the type from request.
    ///
    /// `parts` and `body` are taken by value, so the request is consumed by this call. The
    /// returned future must be [`Send`].
    fn from_request(
        cx: &mut ServerContext,
        parts: Parts,
        body: B,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
}
82
/// Extract a type from query in uri.
///
/// Note that the type must implement [`Deserialize`](serde::Deserialize).
///
/// A uri without a query component is deserialized as if the query were an empty string.
#[cfg(feature = "query")]
#[derive(Debug, Default, Clone, Copy)]
pub struct Query<T>(pub T);
89
/// Extract a type from a urlencoded body.
///
/// Note that the type must implement [`Deserialize`](serde::Deserialize).
///
/// The extractor rejects requests whose `Content-Type` is not
/// `application/x-www-form-urlencoded` before reading the body.
#[cfg(feature = "form")]
#[derive(Debug, Default, Clone, Copy)]
pub struct Form<T>(pub T);
96
/// A wrapper that can extract a type from a json body or convert a type to json response.
///
/// # Examples
///
/// Use [`Json`] as parameter:
///
/// ```
/// use serde::Deserialize;
/// use volo_http::server::{
///     extract::Json,
///     route::{Router, post},
/// };
///
/// #[derive(Debug, Deserialize)]
/// struct User {
///     username: String,
///     password: String,
/// }
///
/// async fn login(Json(user): Json<User>) {
///     println!("user: {user:?}");
/// }
///
/// let router: Router = Router::new().route("/api/v2/login", post(login));
/// ```
///
/// Use [`Json`] as response:
///
/// ```
/// use serde::Serialize;
/// use volo_http::server::{
///     extract::Json,
///     route::{Router, get},
/// };
///
/// #[derive(Debug, Serialize)]
/// struct User {
///     username: String,
///     password: String,
/// }
///
/// async fn user_info() -> Json<User> {
///     let user = User {
///         username: String::from("admin"),
///         password: String::from("passw0rd"),
///     };
///     Json(user)
/// }
///
/// let router: Router = Router::new().route("/api/v2/info", get(user_info));
/// ```
#[cfg(feature = "json")]
#[derive(Debug, Default, Clone, Copy)]
pub struct Json<T>(pub T);
151
/// Extract a [`String`] or [`FastStr`] without checking.
///
/// This type can extract a [`String`] or [`FastStr`] like [`String::from_utf8_unchecked`] or
/// [`FastStr::from_vec_u8_unchecked`]. Note that extracting them is unsafe and users should assume
/// that the value is valid.
///
/// The wrapper stores the raw body bytes; `T` only selects which `assume_valid` conversion
/// is available and is never stored (hence the [`PhantomData`]).
#[derive(Debug, Default, Clone)]
pub struct MaybeInvalid<T>(Vec<u8>, PhantomData<T>);
159
160impl MaybeInvalid<String> {
161    /// Assume the [`String`] is valid and extract it without checking.
162    ///
163    /// # Safety
164    ///
165    /// It is up to the caller to guarantee that the value really is valid. Using this when the
166    /// content is invalid causes immediate undefined behavior.
167    pub unsafe fn assume_valid(self) -> String {
168        unsafe { String::from_utf8_unchecked(self.0) }
169    }
170}
171
172impl MaybeInvalid<FastStr> {
173    /// Assume the [`FastStr`] is valid and extract it without checking.
174    ///
175    /// # Safety
176    ///
177    /// It is up to the caller to guarantee that the value really is valid. Using this when the
178    /// content is invalid causes immediate undefined behavior.
179    pub unsafe fn assume_valid(self) -> FastStr {
180        unsafe { FastStr::from_vec_u8_unchecked(self.0) }
181    }
182}
183
184impl<T> FromContext for Option<T>
185where
186    T: FromContext,
187{
188    type Rejection = Infallible;
189
190    async fn from_context(
191        cx: &mut ServerContext,
192        parts: &mut Parts,
193    ) -> Result<Self, Self::Rejection> {
194        Ok(T::from_context(cx, parts).await.ok())
195    }
196}
197
198impl<T> FromContext for Result<T, T::Rejection>
199where
200    T: FromContext,
201{
202    type Rejection = Infallible;
203
204    async fn from_context(
205        cx: &mut ServerContext,
206        parts: &mut Parts,
207    ) -> Result<Self, Self::Rejection> {
208        Ok(T::from_context(cx, parts).await)
209    }
210}
211
212impl FromContext for Address {
213    type Rejection = Infallible;
214
215    async fn from_context(
216        cx: &mut ServerContext,
217        _parts: &mut Parts,
218    ) -> Result<Address, Self::Rejection> {
219        Ok(cx
220            .rpc_info()
221            .caller()
222            .address()
223            .expect("server context does not have caller address"))
224    }
225}
226
227impl FromContext for Uri {
228    type Rejection = Infallible;
229
230    async fn from_context(
231        _cx: &mut ServerContext,
232        parts: &mut Parts,
233    ) -> Result<Uri, Self::Rejection> {
234        Ok(parts.uri.to_owned())
235    }
236}
237
/// Full uri including scheme, host, path and query.
///
/// The inner [`Uri`] is private; values are produced by the [`FromContext`] implementation
/// of this type.
#[derive(Debug)]
pub struct FullUri(Uri);

// NOTE(review): judging by its name and arguments, this macro (defined elsewhere) implements
// `Deref`/`DerefMut` from `FullUri` to the inner `Uri` stored in field `0` - confirm against
// the macro definition.
impl_deref_and_deref_mut!(FullUri, Uri, 0);
243
244impl fmt::Display for FullUri {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        write!(f, "{}", self.0)
247    }
248}
249
250impl FromContext for FullUri {
251    type Rejection = http::Error;
252
253    async fn from_context(
254        cx: &mut ServerContext,
255        parts: &mut Parts,
256    ) -> Result<Self, Self::Rejection> {
257        let scheme = if is_tls(cx) {
258            Scheme::HTTPS
259        } else {
260            Scheme::HTTP
261        };
262        Uri::builder()
263            .scheme(scheme)
264            .authority(parts.host().map(ToOwned::to_owned).unwrap_or_default())
265            .path_and_query(
266                parts
267                    .uri
268                    .path_and_query()
269                    .map(ToString::to_string)
270                    .unwrap_or(String::from("/")),
271            )
272            .build()
273            .map(FullUri)
274    }
275}
276
277impl IntoResponse for http::Error {
278    fn into_response(self) -> crate::response::Response {
279        StatusCode::INTERNAL_SERVER_ERROR.into_response()
280    }
281}
282
283impl FromContext for Method {
284    type Rejection = Infallible;
285
286    async fn from_context(
287        _cx: &mut ServerContext,
288        parts: &mut Parts,
289    ) -> Result<Method, Self::Rejection> {
290        Ok(parts.method.to_owned())
291    }
292}
293
294impl FromContext for ClientIp {
295    type Rejection = Infallible;
296
297    async fn from_context(cx: &mut ServerContext, _: &mut Parts) -> Result<Self, Self::Rejection> {
298        if let Some(client_ip) = cx.extensions().get::<ClientIp>() {
299            Ok(client_ip.to_owned())
300        } else {
301            Ok(ClientIp(None))
302        }
303    }
304}
305
/// Deserialize the query component of the request uri into `T`.
///
/// A uri without a query component is treated as an empty query string.
#[cfg(feature = "query")]
impl<T> FromContext for Query<T>
where
    T: serde::de::DeserializeOwned,
{
    type Rejection = serde_urlencoded::de::Error;

    async fn from_context(
        _cx: &mut ServerContext,
        parts: &mut Parts,
    ) -> Result<Self, Self::Rejection> {
        let raw_query = parts.uri.query().unwrap_or_default();
        serde_urlencoded::from_str(raw_query).map(Query)
    }
}
322
/// A query deserialization error is answered with `400 Bad Request`; the error value itself
/// is discarded and not put into the response.
#[cfg(feature = "query")]
impl IntoResponse for serde_urlencoded::de::Error {
    fn into_response(self) -> crate::response::Response {
        IntoResponse::into_response(StatusCode::BAD_REQUEST)
    }
}
329
330impl<B, T> FromRequest<B, private::ViaContext> for T
331where
332    B: Send,
333    T: FromContext + Sync,
334{
335    type Rejection = T::Rejection;
336
337    async fn from_request(
338        cx: &mut ServerContext,
339        mut parts: Parts,
340        _: B,
341    ) -> Result<Self, Self::Rejection> {
342        T::from_context(cx, &mut parts).await
343    }
344}
345
346impl<B, T> FromRequest<B> for Option<T>
347where
348    B: Send,
349    T: FromRequest<B, private::ViaRequest> + Sync,
350{
351    type Rejection = Infallible;
352
353    async fn from_request(
354        cx: &mut ServerContext,
355        parts: Parts,
356        body: B,
357    ) -> Result<Self, Self::Rejection> {
358        Ok(T::from_request(cx, parts, body).await.ok())
359    }
360}
361
362impl<B, T> FromRequest<B> for Result<T, T::Rejection>
363where
364    B: Send,
365    T: FromRequest<B, private::ViaRequest> + Sync,
366{
367    type Rejection = Infallible;
368
369    async fn from_request(
370        cx: &mut ServerContext,
371        parts: Parts,
372        body: B,
373    ) -> Result<Self, Self::Rejection> {
374        Ok(T::from_request(cx, parts, body).await)
375    }
376}
377
378impl<B> FromRequest<B> for Request<B>
379where
380    B: Send,
381{
382    type Rejection = Infallible;
383
384    async fn from_request(
385        _cx: &mut ServerContext,
386        parts: Parts,
387        body: B,
388    ) -> Result<Self, Self::Rejection> {
389        Ok(Request::from_parts(parts, body))
390    }
391}
392
393impl<B> FromRequest<B> for Vec<u8>
394where
395    B: Body + Send,
396    B::Data: Send,
397    B::Error: Send,
398{
399    type Rejection = ExtractBodyError;
400
401    async fn from_request(
402        cx: &mut ServerContext,
403        parts: Parts,
404        body: B,
405    ) -> Result<Self, Self::Rejection> {
406        Ok(Bytes::from_request(cx, parts, body).await?.into())
407    }
408}
409
410impl<B> FromRequest<B> for Bytes
411where
412    B: Body + Send,
413    B::Data: Send,
414    B::Error: Send,
415{
416    type Rejection = ExtractBodyError;
417
418    async fn from_request(
419        _: &mut ServerContext,
420        parts: Parts,
421        body: B,
422    ) -> Result<Self, Self::Rejection> {
423        let bytes = body
424            .collect()
425            .await
426            .map_err(|_| body_collection_error())?
427            .to_bytes();
428
429        if let Some(cap) = get_header_value(&parts.headers, header::CONTENT_LENGTH) {
430            if let Ok(cap) = cap.parse::<usize>()
431                && bytes.len() != cap
432            {
433                tracing::warn!(
434                    "[Volo-HTTP] The length of body ({}) does not match the Content-Length ({cap})",
435                    bytes.len(),
436                );
437            }
438        }
439
440        Ok(bytes)
441    }
442}
443
444impl<B> FromRequest<B> for String
445where
446    B: Body + Send,
447    B::Data: Send,
448    B::Error: Send,
449{
450    type Rejection = ExtractBodyError;
451
452    async fn from_request(
453        cx: &mut ServerContext,
454        parts: Parts,
455        body: B,
456    ) -> Result<Self, Self::Rejection> {
457        let vec = Vec::<u8>::from_request(cx, parts, body).await?;
458
459        // Check if the &[u8] is a valid string
460        let _ = simdutf8::basic::from_utf8(&vec).map_err(ExtractBodyError::String)?;
461
462        // SAFETY: The `Vec<u8>` is checked by `simdutf8` and it is a valid `String`
463        Ok(unsafe { String::from_utf8_unchecked(vec) })
464    }
465}
466
467impl<B> FromRequest<B> for FastStr
468where
469    B: Body + Send,
470    B::Data: Send,
471    B::Error: Send,
472{
473    type Rejection = ExtractBodyError;
474
475    async fn from_request(
476        cx: &mut ServerContext,
477        parts: Parts,
478        body: B,
479    ) -> Result<Self, Self::Rejection> {
480        let vec = Vec::<u8>::from_request(cx, parts, body).await?;
481
482        // Check if the &[u8] is a valid string
483        let _ = simdutf8::basic::from_utf8(&vec).map_err(ExtractBodyError::String)?;
484
485        // SAFETY: The `Vec<u8>` is checked by `simdutf8` and it is a valid `String`
486        Ok(unsafe { FastStr::from_vec_u8_unchecked(vec) })
487    }
488}
489
490impl<B, T> FromRequest<B> for MaybeInvalid<T>
491where
492    B: Body + Send,
493    B::Data: Send,
494    B::Error: Send,
495{
496    type Rejection = ExtractBodyError;
497
498    async fn from_request(
499        cx: &mut ServerContext,
500        parts: Parts,
501        body: B,
502    ) -> Result<Self, Self::Rejection> {
503        let vec = Vec::<u8>::from_request(cx, parts, body).await?;
504
505        Ok(MaybeInvalid(vec, PhantomData))
506    }
507}
508
/// Deserialize a urlencoded request body into `T`.
///
/// The `Content-Type` header is checked first, so a request with an unexpected content type
/// is rejected without its body being read.
#[cfg(feature = "form")]
impl<B, T> FromRequest<B> for Form<T>
where
    B: Body + Send,
    B::Data: Send,
    B::Error: Send,
    T: serde::de::DeserializeOwned,
{
    type Rejection = ExtractBodyError;

    async fn from_request(
        cx: &mut ServerContext,
        parts: Parts,
        body: B,
    ) -> Result<Self, Self::Rejection> {
        let is_form =
            content_type_matches(&parts.headers, mime::APPLICATION, mime::WWW_FORM_URLENCODED);
        if !is_form {
            return Err(crate::error::server::invalid_content_type());
        }

        let bytes = Bytes::from_request(cx, parts, body).await?;
        match serde_urlencoded::from_bytes::<T>(bytes.as_ref()) {
            Ok(form) => Ok(Form(form)),
            Err(err) => Err(ExtractBodyError::Form(err)),
        }
    }
}
535
/// Deserialize a json request body into `T`.
///
/// The `Content-Type` header is checked first, so a request with an unexpected content type
/// is rejected without its body being read.
#[cfg(feature = "json")]
impl<B, T> FromRequest<B> for Json<T>
where
    B: Body + Send,
    B::Data: Send,
    B::Error: Send,
    T: serde::de::DeserializeOwned,
{
    type Rejection = ExtractBodyError;

    async fn from_request(
        cx: &mut ServerContext,
        parts: Parts,
        body: B,
    ) -> Result<Self, Self::Rejection> {
        let is_json = content_type_matches(&parts.headers, mime::APPLICATION, mime::JSON);
        if !is_json {
            return Err(crate::error::server::invalid_content_type());
        }

        let bytes = Bytes::from_request(cx, parts, body).await?;
        match crate::utils::json::deserialize(&bytes) {
            Ok(json) => Ok(Json(json)),
            Err(err) => Err(ExtractBodyError::Json(err)),
        }
    }
}
561
/// Report whether the connection uses TLS.
///
/// This variant is compiled when the `__tls` feature is disabled and always returns
/// `false`; the context is not inspected.
#[cfg(not(feature = "__tls"))]
fn is_tls(_: &ServerContext) -> bool {
    false
}
566
/// Report whether the connection uses TLS.
///
/// This variant is compiled when the `__tls` feature is enabled and returns the `is_tls()`
/// flag of the config stored in the context's RPC info.
#[cfg(feature = "__tls")]
fn is_tls(cx: &ServerContext) -> bool {
    cx.rpc_info().config().is_tls()
}
571
572fn get_header_value(map: &HeaderMap, key: HeaderName) -> Option<&str> {
573    map.get(key)?.to_str().ok()
574}
575
/// Check whether the `Content-Type` header in `headers` denotes the media type `ty/subtype`.
///
/// The header matches when its type and subtype equal `ty` and `subtype`, or when its
/// structured-syntax suffix equals `subtype` (for a subtype of `xml`, both `text/xml` and
/// `image/svg+xml` match). Parameters such as `charset` do not take part in the comparison.
///
/// Returns `false` when the header is missing, is rejected by `to_str`, or cannot be parsed
/// as a mime type.
#[cfg(any(feature = "form", feature = "json"))]
fn content_type_matches(
    headers: &HeaderMap,
    ty: mime::Name<'static>,
    subtype: mime::Name<'static>,
) -> bool {
    let parsed = headers
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.parse::<mime::Mime>().ok());

    match parsed {
        None => false,
        Some(mime) => {
            let exact = mime.type_() == ty && mime.subtype() == subtype;
            // The suffix is only consulted when the exact comparison fails.
            exact || mime.suffix() == Some(subtype)
        }
    }
}
597
// Tests for the extractor traits and for the `Form` / `Json` body extractors.
#[cfg(test)]
mod extract_tests {
    #![deny(unused)]

    use std::convert::Infallible;

    use http::request::Parts;

    use super::{FromContext, FromRequest};
    use crate::{body::Body, context::ServerContext, server::handler::Handler};

    // Dummy extractor that only implements `FromContext`. It is used for type checking
    // only and is never actually run.
    struct SomethingFromCx;

    impl FromContext for SomethingFromCx {
        type Rejection = Infallible;
        async fn from_context(
            _: &mut ServerContext,
            _: &mut Parts,
        ) -> Result<Self, Self::Rejection> {
            unimplemented!()
        }
    }

    // Dummy extractor that only implements `FromRequest`. It is used for type checking
    // only and is never actually run.
    struct SomethingFromReq;

    impl FromRequest for SomethingFromReq {
        type Rejection = Infallible;
        async fn from_request(
            _: &mut ServerContext,
            _: Parts,
            _: Body,
        ) -> Result<Self, Self::Rejection> {
            unimplemented!()
        }
    }

    // Compile-time check: each of the functions below must be accepted as a `Handler`.
    // None of the handlers is ever called.
    #[test]
    fn extractor() {
        fn assert_handler<H, T>(_: H)
        where
            H: Handler<T, Body, Infallible>,
        {
        }

        async fn only_cx(_: SomethingFromCx) {}
        async fn only_req(_: SomethingFromReq) {}
        async fn cx_and_req(_: SomethingFromCx, _: SomethingFromReq) {}
        async fn many_cx_and_req(
            _: SomethingFromCx,
            _: SomethingFromCx,
            _: SomethingFromCx,
            _: SomethingFromReq,
        ) {
        }
        async fn only_option_cx(_: Option<SomethingFromCx>) {}
        async fn only_option_req(_: Option<SomethingFromReq>) {}
        async fn only_result_cx(_: Result<SomethingFromCx, Infallible>) {}
        async fn only_result_req(_: Result<SomethingFromReq, Infallible>) {}
        async fn option_cx_req(_: Option<SomethingFromCx>, _: Option<SomethingFromReq>) {}
        async fn result_cx_req(
            _: Result<SomethingFromCx, Infallible>,
            _: Result<SomethingFromReq, Infallible>,
        ) {
        }

        assert_handler(only_cx);
        assert_handler(only_req);
        assert_handler(cx_and_req);
        assert_handler(many_cx_and_req);
        assert_handler(only_option_cx);
        assert_handler(only_option_req);
        assert_handler(only_result_cx);
        assert_handler(only_result_req);
        assert_handler(option_cx_req);
        assert_handler(result_cx_req);
    }

    // Build a request with the given `Content-Type` header and body.
    #[cfg(any(feature = "form", feature = "json"))]
    fn simple_req(content_type: &'static str, body: &'static str) -> crate::request::Request {
        let mut req = crate::request::Request::new(Body::from(body));
        req.headers_mut().insert(
            http::header::CONTENT_TYPE,
            http::header::HeaderValue::from_static(content_type),
        );
        req
    }

    // `Form` accepts a urlencoded body with a matching content type (with or without a
    // charset parameter) and rejects a wrong content type or a malformed body.
    #[cfg(feature = "form")]
    #[tokio::test]
    async fn extract_form() {
        use crate::server::test_helpers;

        #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
        struct TestForm {
            key1: String,
            key2: String,
            key3: String,
        }

        const VALID_FORM: &str = "key1=value1&key2=value2&key3=value3";
        const INVALID_FORM: &str = "if (key && value) { print(key, value) }";

        // Expected value, deserialized directly without going through the extractor.
        let test_form = serde_urlencoded::from_str(VALID_FORM).unwrap();

        // simple content-type
        {
            let req = simple_req("application/x-www-form-urlencoded", VALID_FORM);
            let (parts, body) = req.into_parts();
            assert_eq!(
                super::Form::<TestForm>::from_request(&mut test_helpers::empty_cx(), parts, body,)
                    .await
                    .unwrap()
                    .0,
                test_form,
            );
        }
        // content-type with charset
        {
            let req = simple_req(
                "application/x-www-form-urlencoded; charset=utf-8",
                VALID_FORM,
            );
            let (parts, body) = req.into_parts();
            assert_eq!(
                super::Form::<TestForm>::from_request(&mut test_helpers::empty_cx(), parts, body,)
                    .await
                    .unwrap()
                    .0,
                test_form,
            );
        }
        // wrong content type
        {
            let req = simple_req("text/javascript", VALID_FORM);
            let (parts, body) = req.into_parts();
            super::Form::<TestForm>::from_request(&mut test_helpers::empty_cx(), parts, body)
                .await
                .unwrap_err();
        }
        // invalid form
        {
            let req = simple_req("application/x-www-form-urlencoded", INVALID_FORM);
            let (parts, body) = req.into_parts();
            super::Form::<TestForm>::from_request(&mut test_helpers::empty_cx(), parts, body)
                .await
                .unwrap_err();
        }
    }

    // `Json` accepts a json body with a matching content type (with or without a charset
    // parameter) and rejects a wrong content type or a malformed body.
    #[cfg(feature = "json")]
    #[tokio::test]
    async fn extract_json() {
        use crate::server::test_helpers;

        #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
        struct TestJson {
            key1: String,
            key2: String,
            key3: String,
        }

        const VALID_JSON: &str = r#"{"key1":"value1","key2":"value2", "key3": "value3"}"#;
        const INVALID_JSON: &str = "if (key && value) { print(key, value) }";

        // Expected value, deserialized directly without going through the extractor.
        let test_json = crate::utils::json::deserialize(VALID_JSON.as_bytes()).unwrap();

        // simple content-type
        {
            let req = simple_req("application/json", VALID_JSON);
            let (parts, body) = req.into_parts();
            assert_eq!(
                super::Json::<TestJson>::from_request(&mut test_helpers::empty_cx(), parts, body,)
                    .await
                    .unwrap()
                    .0,
                test_json,
            );
        }
        // content-type with charset
        {
            let req = simple_req("application/json; charset=utf-8", VALID_JSON);
            let (parts, body) = req.into_parts();
            assert_eq!(
                super::Json::<TestJson>::from_request(&mut test_helpers::empty_cx(), parts, body,)
                    .await
                    .unwrap()
                    .0,
                test_json,
            );
        }
        // wrong content type
        {
            let req = simple_req("text/javascript", VALID_JSON);
            let (parts, body) = req.into_parts();
            super::Json::<TestJson>::from_request(&mut test_helpers::empty_cx(), parts, body)
                .await
                .unwrap_err();
        }
        // invalid json
        {
            let req = simple_req("application/json", INVALID_JSON);
            let (parts, body) = req.into_parts();
            super::Json::<TestJson>::from_request(&mut test_helpers::empty_cx(), parts, body)
                .await
                .unwrap_err();
        }
    }
}