1use std::{convert::Infallible, fmt, marker::PhantomData};
6
7use bytes::Bytes;
8use faststr::FastStr;
9use futures_util::Future;
10use http::{
11 header::{self, HeaderMap, HeaderName},
12 method::Method,
13 request::Parts,
14 status::StatusCode,
15 uri::{Scheme, Uri},
16};
17use http_body::Body;
18use http_body_util::BodyExt;
19use volo::{context::Context, net::Address};
20
21use super::IntoResponse;
22use crate::{
23 context::ServerContext,
24 error::server::{ExtractBodyError, body_collection_error},
25 request::{Request, RequestPartsExt},
26 server::utils::client_ip::ClientIp,
27 utils::macros::impl_deref_and_deref_mut,
28};
29
mod private {
    //! Marker types that keep the two routes into `FromRequest` from overlapping.
    //!
    //! Both enums are uninhabited: they are only ever used as type parameters.

    /// Marker for types that reach `FromRequest` through the blanket impl over `FromContext`.
    #[derive(Debug, Clone, Copy)]
    pub enum ViaContext {}

    /// Default marker for types that implement `FromRequest` directly.
    #[derive(Debug, Clone, Copy)]
    pub enum ViaRequest {}
}
37
38pub trait FromContext: Sized {
46 type Rejection: IntoResponse;
52
53 fn from_context(
55 cx: &mut ServerContext,
56 parts: &mut Parts,
57 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
58}
59
60pub trait FromRequest<B = crate::body::Body, M = private::ViaRequest>: Sized {
68 type Rejection: IntoResponse;
74
75 fn from_request(
77 cx: &mut ServerContext,
78 parts: Parts,
79 body: B,
80 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send;
81}
82
#[cfg(feature = "query")]
/// Extractor that deserializes the URI query string into `T`.
///
/// A request without a query string is deserialized from the empty string; deserialization
/// failures are rejected with `400 Bad Request`.
#[derive(Debug, Default, Clone, Copy)]
pub struct Query<T>(pub T);
89
#[cfg(feature = "form")]
/// Extractor that deserializes an `application/x-www-form-urlencoded` body into `T`.
///
/// Requests whose `Content-Type` does not match are rejected before the body is read.
#[derive(Debug, Default, Clone, Copy)]
pub struct Form<T>(pub T);
96
#[cfg(feature = "json")]
/// Extractor that deserializes a JSON request body into `T`.
///
/// The `Content-Type` must be `application/json` or carry a `+json` suffix; otherwise the
/// request is rejected before the body is read.
#[derive(Debug, Default, Clone, Copy)]
pub struct Json<T>(pub T);
151
/// The raw request body, not yet checked for UTF-8 validity.
///
/// `T` names the string type the bytes may later be converted into with `assume_valid`
/// (`String` or `FastStr`).
#[derive(Debug, Default, Clone)]
pub struct MaybeInvalid<T>(Vec<u8>, PhantomData<T>);
159
160impl MaybeInvalid<String> {
161 pub unsafe fn assume_valid(self) -> String {
168 unsafe { String::from_utf8_unchecked(self.0) }
169 }
170}
171
172impl MaybeInvalid<FastStr> {
173 pub unsafe fn assume_valid(self) -> FastStr {
180 unsafe { FastStr::from_vec_u8_unchecked(self.0) }
181 }
182}
183
184impl<T> FromContext for Option<T>
185where
186 T: FromContext,
187{
188 type Rejection = Infallible;
189
190 async fn from_context(
191 cx: &mut ServerContext,
192 parts: &mut Parts,
193 ) -> Result<Self, Self::Rejection> {
194 Ok(T::from_context(cx, parts).await.ok())
195 }
196}
197
198impl<T> FromContext for Result<T, T::Rejection>
199where
200 T: FromContext,
201{
202 type Rejection = Infallible;
203
204 async fn from_context(
205 cx: &mut ServerContext,
206 parts: &mut Parts,
207 ) -> Result<Self, Self::Rejection> {
208 Ok(T::from_context(cx, parts).await)
209 }
210}
211
212impl FromContext for Address {
213 type Rejection = Infallible;
214
215 async fn from_context(
216 cx: &mut ServerContext,
217 _parts: &mut Parts,
218 ) -> Result<Address, Self::Rejection> {
219 Ok(cx
220 .rpc_info()
221 .caller()
222 .address()
223 .expect("server context does not have caller address"))
224 }
225}
226
227impl FromContext for Uri {
228 type Rejection = Infallible;
229
230 async fn from_context(
231 _cx: &mut ServerContext,
232 parts: &mut Parts,
233 ) -> Result<Uri, Self::Rejection> {
234 Ok(parts.uri.to_owned())
235 }
236}
237
238#[derive(Debug)]
240pub struct FullUri(Uri);
241
242impl_deref_and_deref_mut!(FullUri, Uri, 0);
243
244impl fmt::Display for FullUri {
245 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246 write!(f, "{}", self.0)
247 }
248}
249
250impl FromContext for FullUri {
251 type Rejection = http::Error;
252
253 async fn from_context(
254 cx: &mut ServerContext,
255 parts: &mut Parts,
256 ) -> Result<Self, Self::Rejection> {
257 let scheme = if is_tls(cx) {
258 Scheme::HTTPS
259 } else {
260 Scheme::HTTP
261 };
262 Uri::builder()
263 .scheme(scheme)
264 .authority(parts.host().map(ToOwned::to_owned).unwrap_or_default())
265 .path_and_query(
266 parts
267 .uri
268 .path_and_query()
269 .map(ToString::to_string)
270 .unwrap_or(String::from("/")),
271 )
272 .build()
273 .map(FullUri)
274 }
275}
276
277impl IntoResponse for http::Error {
278 fn into_response(self) -> crate::response::Response {
279 StatusCode::INTERNAL_SERVER_ERROR.into_response()
280 }
281}
282
283impl FromContext for Method {
284 type Rejection = Infallible;
285
286 async fn from_context(
287 _cx: &mut ServerContext,
288 parts: &mut Parts,
289 ) -> Result<Method, Self::Rejection> {
290 Ok(parts.method.to_owned())
291 }
292}
293
294impl FromContext for ClientIp {
295 type Rejection = Infallible;
296
297 async fn from_context(cx: &mut ServerContext, _: &mut Parts) -> Result<Self, Self::Rejection> {
298 if let Some(client_ip) = cx.extensions().get::<ClientIp>() {
299 Ok(client_ip.to_owned())
300 } else {
301 Ok(ClientIp(None))
302 }
303 }
304}
305
#[cfg(feature = "query")]
impl<T> FromContext for Query<T>
where
    T: serde::de::DeserializeOwned,
{
    type Rejection = serde_urlencoded::de::Error;

    /// Deserializes the URI query string into `T`; a missing query is treated as empty.
    async fn from_context(
        _cx: &mut ServerContext,
        parts: &mut Parts,
    ) -> Result<Self, Self::Rejection> {
        let raw_query = parts.uri.query().unwrap_or("");
        serde_urlencoded::from_str(raw_query).map(Query)
    }
}
322
#[cfg(feature = "query")]
impl IntoResponse for serde_urlencoded::de::Error {
    /// A query string that fails to deserialize is the client's fault: `400 Bad Request`.
    fn into_response(self) -> crate::response::Response {
        let status = StatusCode::BAD_REQUEST;
        status.into_response()
    }
}
329
330impl<B, T> FromRequest<B, private::ViaContext> for T
331where
332 B: Send,
333 T: FromContext + Sync,
334{
335 type Rejection = T::Rejection;
336
337 async fn from_request(
338 cx: &mut ServerContext,
339 mut parts: Parts,
340 _: B,
341 ) -> Result<Self, Self::Rejection> {
342 T::from_context(cx, &mut parts).await
343 }
344}
345
346impl<B, T> FromRequest<B> for Option<T>
347where
348 B: Send,
349 T: FromRequest<B, private::ViaRequest> + Sync,
350{
351 type Rejection = Infallible;
352
353 async fn from_request(
354 cx: &mut ServerContext,
355 parts: Parts,
356 body: B,
357 ) -> Result<Self, Self::Rejection> {
358 Ok(T::from_request(cx, parts, body).await.ok())
359 }
360}
361
362impl<B, T> FromRequest<B> for Result<T, T::Rejection>
363where
364 B: Send,
365 T: FromRequest<B, private::ViaRequest> + Sync,
366{
367 type Rejection = Infallible;
368
369 async fn from_request(
370 cx: &mut ServerContext,
371 parts: Parts,
372 body: B,
373 ) -> Result<Self, Self::Rejection> {
374 Ok(T::from_request(cx, parts, body).await)
375 }
376}
377
378impl<B> FromRequest<B> for Request<B>
379where
380 B: Send,
381{
382 type Rejection = Infallible;
383
384 async fn from_request(
385 _cx: &mut ServerContext,
386 parts: Parts,
387 body: B,
388 ) -> Result<Self, Self::Rejection> {
389 Ok(Request::from_parts(parts, body))
390 }
391}
392
393impl<B> FromRequest<B> for Vec<u8>
394where
395 B: Body + Send,
396 B::Data: Send,
397 B::Error: Send,
398{
399 type Rejection = ExtractBodyError;
400
401 async fn from_request(
402 cx: &mut ServerContext,
403 parts: Parts,
404 body: B,
405 ) -> Result<Self, Self::Rejection> {
406 Ok(Bytes::from_request(cx, parts, body).await?.into())
407 }
408}
409
410impl<B> FromRequest<B> for Bytes
411where
412 B: Body + Send,
413 B::Data: Send,
414 B::Error: Send,
415{
416 type Rejection = ExtractBodyError;
417
418 async fn from_request(
419 _: &mut ServerContext,
420 parts: Parts,
421 body: B,
422 ) -> Result<Self, Self::Rejection> {
423 let bytes = body
424 .collect()
425 .await
426 .map_err(|_| body_collection_error())?
427 .to_bytes();
428
429 if let Some(cap) = get_header_value(&parts.headers, header::CONTENT_LENGTH) {
430 if let Ok(cap) = cap.parse::<usize>()
431 && bytes.len() != cap
432 {
433 tracing::warn!(
434 "[Volo-HTTP] The length of body ({}) does not match the Content-Length ({cap})",
435 bytes.len(),
436 );
437 }
438 }
439
440 Ok(bytes)
441 }
442}
443
444impl<B> FromRequest<B> for String
445where
446 B: Body + Send,
447 B::Data: Send,
448 B::Error: Send,
449{
450 type Rejection = ExtractBodyError;
451
452 async fn from_request(
453 cx: &mut ServerContext,
454 parts: Parts,
455 body: B,
456 ) -> Result<Self, Self::Rejection> {
457 let vec = Vec::<u8>::from_request(cx, parts, body).await?;
458
459 let _ = simdutf8::basic::from_utf8(&vec).map_err(ExtractBodyError::String)?;
461
462 Ok(unsafe { String::from_utf8_unchecked(vec) })
464 }
465}
466
467impl<B> FromRequest<B> for FastStr
468where
469 B: Body + Send,
470 B::Data: Send,
471 B::Error: Send,
472{
473 type Rejection = ExtractBodyError;
474
475 async fn from_request(
476 cx: &mut ServerContext,
477 parts: Parts,
478 body: B,
479 ) -> Result<Self, Self::Rejection> {
480 let vec = Vec::<u8>::from_request(cx, parts, body).await?;
481
482 let _ = simdutf8::basic::from_utf8(&vec).map_err(ExtractBodyError::String)?;
484
485 Ok(unsafe { FastStr::from_vec_u8_unchecked(vec) })
487 }
488}
489
490impl<B, T> FromRequest<B> for MaybeInvalid<T>
491where
492 B: Body + Send,
493 B::Data: Send,
494 B::Error: Send,
495{
496 type Rejection = ExtractBodyError;
497
498 async fn from_request(
499 cx: &mut ServerContext,
500 parts: Parts,
501 body: B,
502 ) -> Result<Self, Self::Rejection> {
503 let vec = Vec::<u8>::from_request(cx, parts, body).await?;
504
505 Ok(MaybeInvalid(vec, PhantomData))
506 }
507}
508
#[cfg(feature = "form")]
impl<B, T> FromRequest<B> for Form<T>
where
    B: Body + Send,
    B::Data: Send,
    B::Error: Send,
    T: serde::de::DeserializeOwned,
{
    type Rejection = ExtractBodyError;

    /// Checks the `Content-Type`, then collects and URL-decodes the body into `T`.
    ///
    /// # Errors
    ///
    /// - An invalid-content-type error, before the body is read, when the `Content-Type` is
    ///   not `application/x-www-form-urlencoded`.
    /// - Body-collection errors.
    /// - `ExtractBodyError::Form` when deserialization fails.
    async fn from_request(
        cx: &mut ServerContext,
        parts: Parts,
        body: B,
    ) -> Result<Self, Self::Rejection> {
        let is_form =
            content_type_matches(&parts.headers, mime::APPLICATION, mime::WWW_FORM_URLENCODED);
        if !is_form {
            return Err(crate::error::server::invalid_content_type());
        }

        let payload = Bytes::from_request(cx, parts, body).await?;
        serde_urlencoded::from_bytes::<T>(&payload[..])
            .map(Form)
            .map_err(ExtractBodyError::Form)
    }
}
535
#[cfg(feature = "json")]
impl<B, T> FromRequest<B> for Json<T>
where
    B: Body + Send,
    B::Data: Send,
    B::Error: Send,
    T: serde::de::DeserializeOwned,
{
    type Rejection = ExtractBodyError;

    /// Checks the `Content-Type`, then collects and deserializes the JSON body into `T`.
    ///
    /// # Errors
    ///
    /// - An invalid-content-type error, before the body is read, when the `Content-Type` is
    ///   neither `application/json` nor `+json`-suffixed.
    /// - Body-collection errors.
    /// - `ExtractBodyError::Json` when deserialization fails.
    async fn from_request(
        cx: &mut ServerContext,
        parts: Parts,
        body: B,
    ) -> Result<Self, Self::Rejection> {
        let is_json = content_type_matches(&parts.headers, mime::APPLICATION, mime::JSON);
        if !is_json {
            return Err(crate::error::server::invalid_content_type());
        }

        let payload = Bytes::from_request(cx, parts, body).await?;
        crate::utils::json::deserialize(&payload)
            .map(Json)
            .map_err(ExtractBodyError::Json)
    }
}
561
562#[cfg(not(feature = "__tls"))]
563fn is_tls(_: &ServerContext) -> bool {
564 false
565}
566
#[cfg(feature = "__tls")]
/// Reports whether the connection is TLS, as recorded in the RPC info's config.
fn is_tls(cx: &ServerContext) -> bool {
    cx.rpc_info()
        .config()
        .is_tls()
}
571
572fn get_header_value(map: &HeaderMap, key: HeaderName) -> Option<&str> {
573 map.get(key)?.to_str().ok()
574}
575
#[cfg(any(feature = "form", feature = "json"))]
/// Checks whether the `Content-Type` header matches `ty/subtype`.
///
/// Parameters such as `charset` are ignored. A structured-syntax suffix equal to `subtype`
/// (e.g. `+json`) also matches, whatever the top-level type. A missing, non-string or
/// unparsable header never matches.
fn content_type_matches(
    headers: &HeaderMap,
    ty: mime::Name<'static>,
    subtype: mime::Name<'static>,
) -> bool {
    let parsed = headers
        .get(header::CONTENT_TYPE)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse::<mime::Mime>().ok());
    let Some(mime) = parsed else {
        return false;
    };

    let exact = mime.type_() == ty && mime.subtype() == subtype;
    exact || mime.suffix() == Some(subtype)
}
597
#[cfg(test)]
mod extract_tests {
    #![deny(unused)]

    use std::convert::Infallible;

    use http::request::Parts;

    use super::{FromContext, FromRequest};
    use crate::{body::Body, context::ServerContext, server::handler::Handler};

    /// Extractor implementing only `FromContext`; it is never actually invoked.
    struct SomethingFromCx;

    impl FromContext for SomethingFromCx {
        type Rejection = Infallible;

        async fn from_context(
            _: &mut ServerContext,
            _: &mut Parts,
        ) -> Result<Self, Self::Rejection> {
            unimplemented!()
        }
    }

    /// Extractor implementing only `FromRequest`; it is never actually invoked.
    struct SomethingFromReq;

    impl FromRequest for SomethingFromReq {
        type Rejection = Infallible;

        async fn from_request(
            _: &mut ServerContext,
            _: Parts,
            _: Body,
        ) -> Result<Self, Self::Rejection> {
            unimplemented!()
        }
    }

    /// Compile-time check: every handler signature below must be accepted as a `Handler`.
    #[test]
    fn extractor() {
        fn assert_handler<H, T>(_: H)
        where
            H: Handler<T, Body, Infallible>,
        {
        }

        async fn only_cx(_: SomethingFromCx) {}
        assert_handler(only_cx);

        async fn only_req(_: SomethingFromReq) {}
        assert_handler(only_req);

        async fn cx_and_req(_: SomethingFromCx, _: SomethingFromReq) {}
        assert_handler(cx_and_req);

        async fn many_cx_and_req(
            _: SomethingFromCx,
            _: SomethingFromCx,
            _: SomethingFromCx,
            _: SomethingFromReq,
        ) {
        }
        assert_handler(many_cx_and_req);

        async fn only_option_cx(_: Option<SomethingFromCx>) {}
        assert_handler(only_option_cx);

        async fn only_option_req(_: Option<SomethingFromReq>) {}
        assert_handler(only_option_req);

        async fn only_result_cx(_: Result<SomethingFromCx, Infallible>) {}
        assert_handler(only_result_cx);

        async fn only_result_req(_: Result<SomethingFromReq, Infallible>) {}
        assert_handler(only_result_req);

        async fn option_cx_req(_: Option<SomethingFromCx>, _: Option<SomethingFromReq>) {}
        assert_handler(option_cx_req);

        async fn result_cx_req(
            _: Result<SomethingFromCx, Infallible>,
            _: Result<SomethingFromReq, Infallible>,
        ) {
        }
        assert_handler(result_cx_req);
    }

    /// Builds a request carrying `body` with the given `Content-Type` header.
    #[cfg(any(feature = "form", feature = "json"))]
    fn simple_req(content_type: &'static str, body: &'static str) -> crate::request::Request {
        let mut req = crate::request::Request::new(Body::from(body));
        let value = http::header::HeaderValue::from_static(content_type);
        req.headers_mut().insert(http::header::CONTENT_TYPE, value);
        req
    }

    #[cfg(feature = "form")]
    #[tokio::test]
    async fn extract_form() {
        #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
        struct TestForm {
            key1: String,
            key2: String,
            key3: String,
        }

        const VALID_FORM: &str = "key1=value1&key2=value2&key3=value3";
        const INVALID_FORM: &str = "if (key && value) { print(key, value) }";

        // Runs the `Form` extractor; `None` means the request was rejected.
        async fn extract(content_type: &'static str, body: &'static str) -> Option<TestForm> {
            let (parts, body) = simple_req(content_type, body).into_parts();
            super::Form::<TestForm>::from_request(
                &mut crate::server::test_helpers::empty_cx(),
                parts,
                body,
            )
            .await
            .ok()
            .map(|form| form.0)
        }

        let expected: TestForm = serde_urlencoded::from_str(VALID_FORM).unwrap();

        // The bare media type is accepted.
        assert_eq!(
            extract("application/x-www-form-urlencoded", VALID_FORM)
                .await
                .as_ref(),
            Some(&expected),
        );
        // Media-type parameters such as `charset` do not prevent a match.
        assert_eq!(
            extract("application/x-www-form-urlencoded; charset=utf-8", VALID_FORM)
                .await
                .as_ref(),
            Some(&expected),
        );
        // A foreign content type is rejected even though the body is a valid form.
        assert!(extract("text/javascript", VALID_FORM).await.is_none());
        // The right content type with a malformed body is rejected as well.
        assert!(
            extract("application/x-www-form-urlencoded", INVALID_FORM)
                .await
                .is_none()
        );
    }

    #[cfg(feature = "json")]
    #[tokio::test]
    async fn extract_json() {
        #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
        struct TestJson {
            key1: String,
            key2: String,
            key3: String,
        }

        const VALID_JSON: &str = r#"{"key1":"value1","key2":"value2", "key3": "value3"}"#;
        const INVALID_JSON: &str = "if (key && value) { print(key, value) }";

        // Runs the `Json` extractor; `None` means the request was rejected.
        async fn extract(content_type: &'static str, body: &'static str) -> Option<TestJson> {
            let (parts, body) = simple_req(content_type, body).into_parts();
            super::Json::<TestJson>::from_request(
                &mut crate::server::test_helpers::empty_cx(),
                parts,
                body,
            )
            .await
            .ok()
            .map(|json| json.0)
        }

        let expected: TestJson =
            crate::utils::json::deserialize(VALID_JSON.as_bytes()).unwrap();

        // The bare media type is accepted.
        assert_eq!(
            extract("application/json", VALID_JSON).await.as_ref(),
            Some(&expected),
        );
        // Media-type parameters such as `charset` do not prevent a match.
        assert_eq!(
            extract("application/json; charset=utf-8", VALID_JSON)
                .await
                .as_ref(),
            Some(&expected),
        );
        // A foreign content type is rejected even though the body is valid JSON.
        assert!(extract("text/javascript", VALID_JSON).await.is_none());
        // The right content type with a malformed body is rejected as well.
        assert!(extract("application/json", INVALID_JSON).await.is_none());
    }
}
805}