Skip to main content

dioxus_fullstack/
request.rs

1use dioxus_fullstack_core::{RequestError, ServerFnError};
2#[cfg(feature = "server")]
3use headers::Header;
4use http::response::Parts;
5use std::{future::Future, pin::Pin};
6
7use crate::{ClientRequest, ClientResponse, ErrorPayload};
8
9/// The `IntoRequest` trait allows types to be used as the body of a request to a HTTP endpoint or server function.
10///
11/// `IntoRequest` allows for types handle the calling of `ClientRequest::send` where the result is then
12/// passed to `FromResponse` to decode the response.
13///
14/// You can think of the `IntoRequest` and `FromResponse` traits are "inverse" traits of the axum
15/// `FromRequest` and `IntoResponse` traits. Just like a type can be decoded from a request via `FromRequest`,
16/// a type can be encoded into a request via `IntoRequest`.
17///
18/// ## Generic State
19///
20/// `IntoRequest` is generic over the response type `R` which defaults to `ClientResponse`. The default
21/// `ClientResponse` is the base response type that internally wraps `reqwest::Response`.
22///
23/// However, some responses might need state from the initial request to properly decode the response.
24/// Most state can be extended via the `.extension()` method on `ClientRequest`. In some cases, like
25/// websockets, the response needs to retain an initial connection from the request. Here, you can use
26///  the `R` generic to specify a concrete response type. The resulting type that implements `FromResponse`
27/// must also be generic over the same `R` type.
28pub trait IntoRequest<R = ClientResponse>: Sized {
29    fn into_request(
30        self,
31        req: ClientRequest,
32    ) -> impl Future<Output = Result<R, RequestError>> + 'static;
33}
34
35impl<A, R> IntoRequest<R> for (A,)
36where
37    A: IntoRequest<R> + 'static + Send,
38{
39    fn into_request(
40        self,
41        req: ClientRequest,
42    ) -> impl Future<Output = Result<R, RequestError>> + 'static {
43        A::into_request(self.0, req)
44    }
45}
46
47pub trait FromResponse<R = ClientResponse>: Sized {
48    fn from_response(res: R) -> impl Future<Output = Result<Self, ServerFnError>>;
49}
50
51impl<A> FromResponse for A
52where
53    A: FromResponseParts,
54{
55    fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
56        async move {
57            let status = res.status();
58
59            if status.is_success() {
60                let (parts, _body) = res.into_parts();
61                let mut parts = parts;
62                A::from_response_parts(&mut parts)
63            } else {
64                let ErrorPayload::<serde_json::Value> {
65                    message,
66                    code,
67                    data,
68                } = res.json().await?;
69                Err(ServerFnError::ServerError {
70                    message,
71                    code,
72                    details: data,
73                })
74            }
75        }
76    }
77}
78
79impl<A, B> FromResponse for (A, B)
80where
81    A: FromResponseParts,
82    B: FromResponse,
83{
84    fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
85        async move {
86            let mut parts = res.make_parts();
87            let a = A::from_response_parts(&mut parts)?;
88            let b = B::from_response(res).await?;
89            Ok((a, b))
90        }
91    }
92}
93
94impl<A, B, C> FromResponse for (A, B, C)
95where
96    A: FromResponseParts,
97    B: FromResponseParts,
98    C: FromResponse,
99{
100    fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
101        async move {
102            let mut parts = res.make_parts();
103            let a = A::from_response_parts(&mut parts)?;
104            let b = B::from_response_parts(&mut parts)?;
105            let c = C::from_response(res).await?;
106            Ok((a, b, c))
107        }
108    }
109}
110
111pub trait FromResponseParts
112where
113    Self: Sized,
114{
115    fn from_response_parts(parts: &mut Parts) -> Result<Self, ServerFnError>;
116}
117
/// Extracts a typed header `T` from the response head.
#[cfg(feature = "server")]
impl<T: Header> FromResponseParts for axum_extra::TypedHeader<T> {
    /// Decodes header `T` from `parts.headers`.
    ///
    /// # Errors
    ///
    /// Returns `ServerFnError::Serialization` when the header is absent, or when it is present
    /// but cannot be decoded as `T`. The message names the header and says which case occurred.
    fn from_response_parts(parts: &mut Parts) -> Result<Self, ServerFnError> {
        use headers::HeaderMapExt;

        let name = T::name();
        // `typed_try_get` separates "not present" (`Ok(None)`) from "present but malformed"
        // (`Err`); `typed_get` collapses both into `None`.
        match parts.headers.typed_try_get::<T>() {
            Ok(Some(value)) => Ok(axum_extra::TypedHeader(value)),
            Ok(None) => Err(ServerFnError::Serialization(format!(
                "Missing header `{name}`"
            ))),
            Err(_) => Err(ServerFnError::Serialization(format!(
                "Invalid header value for `{name}`"
            ))),
        }
    }
}
131
132/*
133todo: make the serverfns return ServerFnRequest which lets us control the future better
134*/
135#[pin_project::pin_project]
136#[must_use = "Requests do nothing unless you `.await` them"]
137pub struct ServerFnRequest<Output> {
138    _phantom: std::marker::PhantomData<Output>,
139    #[pin]
140    fut: Pin<Box<dyn Future<Output = Output> + Send>>,
141}
142
143impl<O> ServerFnRequest<O> {
144    pub fn new(res: impl Future<Output = O> + Send + 'static) -> Self {
145        ServerFnRequest {
146            _phantom: std::marker::PhantomData,
147            fut: Box::pin(res),
148        }
149    }
150}
151
152impl<T, E> std::future::Future for ServerFnRequest<Result<T, E>> {
153    type Output = Result<T, E>;
154
155    fn poll(
156        self: std::pin::Pin<&mut Self>,
157        cx: &mut std::task::Context<'_>,
158    ) -> std::task::Poll<Self::Output> {
159        self.project().fut.poll(cx)
160    }
161}
162
/// Compile-time marker implemented only for `Result<T, E>`.
///
/// Checked via `assert_is_result` (presumably from macro-generated code) so that a non-`Result`
/// return type produces the tailored diagnostic below.
#[doc(hidden)]
#[diagnostic::on_unimplemented(
    message = "The return type of a server function must be `Result<T, E>`",
    note = "`T` is either `impl IntoResponse` *or* `impl Serialize`",
    note = "`E` is either `From<ServerFnError> + Serialize`, `dioxus::CapturedError` or `StatusCode`."
)]
pub trait AssertIsResult {}

impl<T, E> AssertIsResult for Result<T, E> {}

/// Compiles only when `T` is a `Result<_, _>`; does nothing at runtime.
#[doc(hidden)]
pub fn assert_is_result<T>()
where
    T: AssertIsResult,
{
}
174
/// Marker trait checked against a server function's arguments.
///
/// Only `EncodeIsVerified` implements it here. A type without an impl (such as `CantEncode`)
/// makes the check fail with the custom diagnostic attached to this trait.
#[diagnostic::on_unimplemented(message = r#"❌ Invalid Arguments to ServerFn ❌

The arguments to the server function must be either:

- a single `impl FromRequest + IntoRequest` argument
- or multiple `DeserializeOwned` arguments.

Did you forget to implement `IntoRequest` or `Deserialize` for one of the arguments?

`IntoRequest` is a trait that allows payloads to be sent to the server function.

> See https://dioxuslabs.com/learn/0.7/essentials/fullstack/server_functions for more details.

"#)]
pub trait AssertCanEncode {}

/// Witness that the arguments could not be encoded; deliberately has no `AssertCanEncode` impl.
pub struct CantEncode;

/// Witness that the arguments can be encoded.
pub struct EncodeIsVerified;

impl AssertCanEncode for EncodeIsVerified {}
195
/// Marker trait checked against a server function's return type.
///
/// Only `DecodeIsVerified` implements it here. A type without an impl (such as `CantDecode`)
/// makes the check fail with the custom diagnostic attached to this trait.
#[diagnostic::on_unimplemented(message = r#"❌ Invalid return type from ServerFn ❌

The return type of the server function must be either:

- a single `impl FromResponse` return type
- or a single `impl Serialize + DeserializeOwned` return type

Did you forget to implement `FromResponse` or `DeserializeOwned` for the return type?

`FromResponse` is a trait that allows payloads to be decoded from the server function response.

> See https://dioxuslabs.com/learn/0.7/essentials/fullstack/server_functions for more details.

"#)]
pub trait AssertCanDecode {}

/// Witness that the return type could not be decoded; deliberately has no `AssertCanDecode` impl.
pub struct CantDecode;

/// Witness that the return type can be decoded.
pub struct DecodeIsVerified;

impl AssertCanDecode for DecodeIsVerified {}
214
215#[doc(hidden)]
216pub fn assert_can_encode(_t: impl AssertCanEncode) {}
217
218#[doc(hidden)]
219pub fn assert_can_decode(_t: impl AssertCanDecode) {}
220
#[cfg(test)]
mod test {
    use http::Extensions;

    use super::*;

    /// Extractor that accepts any response head, so these tests observe only the status
    /// handling of the blanket `FromResponse` impl.
    #[derive(Debug)]
    struct TestFromResponse;

    impl FromResponseParts for TestFromResponse {
        fn from_response_parts(_parts: &mut Parts) -> Result<Self, ServerFnError> {
            Ok(TestFromResponse)
        }
    }

    /// Wraps an in-memory `reqwest::Response` with the given status and raw body in a
    /// `ClientResponse` that carries no extensions.
    fn build_response(status: u16, body: &str) -> ClientResponse {
        let http_response = http::Response::builder()
            .status(status)
            .body(body.as_bytes().to_vec())
            .unwrap();
        ClientResponse {
            response: Box::new(reqwest::Response::from(http_response)),
            extensions: Extensions::new(),
        }
    }

    /// Runs `TestFromResponse::from_response` to completion on the current thread.
    fn decode(status: u16, body: &str) -> Result<TestFromResponse, ServerFnError> {
        let response = build_response(status, body);
        futures::executor::block_on(TestFromResponse::from_response(response))
    }

    #[test]
    fn fromresponseparts_path_decodes_ok_on_2xx() {
        let result = decode(200, "");
        assert!(
            result.is_ok(),
            "expected Ok(..) for HTTP 200 success case, got: {:?}",
            result
        );
    }

    #[test]
    fn fromresponseparts_falls_back_to_request_error_on_unparsable_error_body() {
        let result = decode(400, "");
        assert!(result.is_err(), "expected Err(..) for HTTP 400 failed case");
        let error = result.unwrap_err();
        assert!(matches!(
            error,
            ServerFnError::Request(RequestError::Decode(_))
        ));
    }

    #[test]
    fn fromresponseparts_parses_error_payload_on_http_error() {
        let body = r#"{
                "message": "qwerty",
                "code": 400
            }"#;
        let result = decode(400, body);
        assert!(result.is_err(), "expected Err(..) for HTTP 400 failed case");
        let error = result.unwrap_err();
        assert_eq!(
            error,
            ServerFnError::ServerError {
                message: "qwerty".to_string(),
                code: 400,
                details: None
            }
        );
    }
}