1use crate::{
39 CantEncode, ClientRequest, ClientResponse, EncodeIsVerified, FromResponse, HttpError,
40 IntoRequest, ServerFnError,
41};
42use axum::response::IntoResponse;
43use axum_core::extract::{FromRequest, Request};
44use bytes::Bytes;
45use dioxus_fullstack_core::RequestError;
46use http::StatusCode;
47use send_wrapper::SendWrapper;
48use serde::Serialize;
49use serde::{de::DeserializeOwned, Deserialize};
50use std::fmt::Display;
51use std::{marker::PhantomData, prelude::rust_2024::Future};
52
/// Zero-sized marker whose reference depth (`&&&…ServerFnEncoder`) selects among the
/// `EncodeRequest` / `ExtractRequest` implementations in this file (autoref-based dispatch).
///
/// The `fn() -> (In, Out)` phantom stores no `In`/`Out` value, so the marker stays `Send`,
/// `Sync` and covariant whatever those types are.
#[doc(hidden)]
pub struct ServerFnEncoder<In, Out>(PhantomData<fn() -> (In, Out)>);

impl<In, Out> ServerFnEncoder<In, Out> {
    /// Creates the marker value.
    #[doc(hidden)]
    pub fn new() -> Self {
        ServerFnEncoder(PhantomData)
    }
}

// Written by hand: `#[derive(Default)]` would add needless `In: Default, Out: Default` bounds.
impl<In, Out> Default for ServerFnEncoder<In, Out> {
    fn default() -> Self {
        Self::new()
    }
}
61
/// Zero-sized marker whose reference depth (`&&&…ServerFnDecoder`) selects among the
/// decoding / response-building implementations in this file (autoref-based dispatch).
///
/// `Out` is the server function's return type, typically `Result<T, E>`.
#[doc(hidden)]
pub struct ServerFnDecoder<Out>(PhantomData<fn() -> Out>);

impl<Out> ServerFnDecoder<Out> {
    /// Creates the marker value.
    #[doc(hidden)]
    pub fn new() -> Self {
        ServerFnDecoder(PhantomData)
    }
}

// Written by hand: `#[derive(Default)]` would add a needless `Out: Default` bound.
impl<Out> Default for ServerFnDecoder<Out> {
    fn default() -> Self {
        Self::new()
    }
}
70
71#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
79pub enum RestEndpointPayload<T, E> {
80 #[serde(rename = "success")]
81 Success(T),
82
83 #[serde(rename = "error")]
84 Error(ErrorPayload<E>),
85}
86
87#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
89pub struct ErrorPayload<E> {
90 message: String,
91
92 code: u16,
93
94 #[serde(skip_serializing_if = "Option::is_none")]
95 data: Option<E>,
96}
97
98pub fn reqwest_response_to_serverfn_err(err: reqwest::Error) -> ServerFnError {
102 ServerFnError::Request(reqwest_error_to_request_error(err))
103}
104
105pub fn reqwest_error_to_request_error(err: reqwest::Error) -> RequestError {
106 let message = err.to_string();
107 if err.is_timeout() {
108 RequestError::Timeout(message)
109 } else if err.is_request() {
110 RequestError::Request(message)
111 } else if err.is_body() {
112 RequestError::Body(message)
113 } else if err.is_decode() {
114 RequestError::Decode(message)
115 } else if err.is_redirect() {
116 RequestError::Redirect(message)
117 } else if let Some(status) = err.status() {
118 RequestError::Status(message, status.as_u16())
119 } else {
120 #[cfg(not(target_arch = "wasm32"))]
121 {
122 if err.is_connect() {
123 RequestError::Connect(message)
124 } else {
125 RequestError::Request(message)
126 }
127 }
128
129 #[cfg(target_arch = "wasm32")]
130 {
131 RequestError::Request(message)
132 }
133 }
134}
135
136pub use req_to::*;
137pub mod req_to {
138 use super::*;
139
140 pub trait EncodeRequest<In, Out, R> {
141 type VerifyEncode;
142 fn fetch_client(
143 &self,
144 ctx: ClientRequest,
145 data: In,
146 map: fn(In) -> Out,
147 ) -> impl Future<Output = Result<R, RequestError>> + 'static;
148 fn verify_can_serialize(&self) -> Self::VerifyEncode;
149 }
150
151 impl<T, O> EncodeRequest<T, O, ClientResponse> for &&&&&&&&&&ServerFnEncoder<T, O>
153 where
154 T: DeserializeOwned + Serialize + 'static,
155 {
156 type VerifyEncode = EncodeIsVerified;
157 fn fetch_client(
158 &self,
159 ctx: ClientRequest,
160 data: T,
161 _map: fn(T) -> O,
162 ) -> impl Future<Output = Result<ClientResponse, RequestError>> + 'static {
163 async move { ctx.send_json(&data).await }
164 }
165
166 fn verify_can_serialize(&self) -> Self::VerifyEncode {
167 EncodeIsVerified
168 }
169 }
170
171 impl<T, O, R> EncodeRequest<T, O, R> for &&&&&&&&&ServerFnEncoder<T, O>
173 where
174 T: 'static,
175 O: IntoRequest<R>,
176 {
177 type VerifyEncode = EncodeIsVerified;
178 fn fetch_client(
179 &self,
180 ctx: ClientRequest,
181 data: T,
182 map: fn(T) -> O,
183 ) -> impl Future<Output = Result<R, RequestError>> + 'static {
184 O::into_request(map(data), ctx)
185 }
186
187 fn verify_can_serialize(&self) -> Self::VerifyEncode {
188 EncodeIsVerified
189 }
190 }
191
192 impl<T, O> EncodeRequest<T, O, ClientResponse> for &ServerFnEncoder<T, O>
194 where
195 T: 'static,
196 {
197 type VerifyEncode = CantEncode;
198 #[allow(clippy::manual_async_fn)]
199 fn fetch_client(
200 &self,
201 _ctx: ClientRequest,
202 _data: T,
203 _map: fn(T) -> O,
204 ) -> impl Future<Output = Result<ClientResponse, RequestError>> + 'static {
205 async move { unimplemented!() }
206 }
207
208 fn verify_can_serialize(&self) -> Self::VerifyEncode {
209 CantEncode
210 }
211 }
212}
213
214pub use decode_ok::*;
215mod decode_ok {
216
217 use crate::{CantDecode, DecodeIsVerified};
218
219 use super::*;
220
221 pub trait RequestDecodeResult<T, R> {
227 type VerifyDecode;
228 fn decode_client_response(
229 &self,
230 res: Result<R, RequestError>,
231 ) -> impl Future<Output = Result<Result<T, ServerFnError>, RequestError>> + Send;
232 fn verify_can_deserialize(&self) -> Self::VerifyDecode;
233 }
234
235 impl<T: FromResponse<R>, E, R> RequestDecodeResult<T, R> for &&&ServerFnDecoder<Result<T, E>> {
236 type VerifyDecode = DecodeIsVerified;
237 fn decode_client_response(
238 &self,
239 res: Result<R, RequestError>,
240 ) -> impl Future<Output = Result<Result<T, ServerFnError>, RequestError>> + Send {
241 SendWrapper::new(async move {
242 match res {
243 Err(err) => Err(err),
244 Ok(res) => Ok(T::from_response(res).await),
245 }
246 })
247 }
248 fn verify_can_deserialize(&self) -> Self::VerifyDecode {
249 DecodeIsVerified
250 }
251 }
252
253 impl<T: DeserializeOwned, E> RequestDecodeResult<T, ClientResponse>
254 for &&ServerFnDecoder<Result<T, E>>
255 {
256 type VerifyDecode = DecodeIsVerified;
257 fn decode_client_response(
258 &self,
259 res: Result<ClientResponse, RequestError>,
260 ) -> impl Future<Output = Result<Result<T, ServerFnError>, RequestError>> + Send {
261 SendWrapper::new(async move {
262 match res {
263 Err(err) => Err(err),
264 Ok(res) => {
265 let status = res.status();
266
267 let bytes = res.bytes().await.unwrap();
268 let as_bytes = if bytes.is_empty() {
269 b"null".as_slice()
270 } else {
271 &bytes
272 };
273
274 let res = if status.is_success() {
275 serde_json::from_slice::<T>(as_bytes)
276 .map(RestEndpointPayload::Success)
277 .map_err(|e| ServerFnError::Deserialization(e.to_string()))
278 } else {
279 match serde_json::from_slice::<ErrorPayload<serde_json::Value>>(
280 as_bytes,
281 ) {
282 Ok(res) => Ok(RestEndpointPayload::Error(ErrorPayload {
283 message: res.message,
284 code: res.code,
285 data: res.data,
286 })),
287 Err(err) => {
288 if let Ok(text) = String::from_utf8(as_bytes.to_vec()) {
289 Ok(RestEndpointPayload::Error(ErrorPayload {
290 message: format!("HTTP {}: {}", status.as_u16(), text),
291 code: status.as_u16(),
292 data: None,
293 }))
294 } else {
295 Err(ServerFnError::Deserialization(err.to_string()))
296 }
297 }
298 }
299 };
300
301 match res {
302 Ok(RestEndpointPayload::Success(t)) => Ok(Ok(t)),
303 Ok(RestEndpointPayload::Error(err)) => {
304 Ok(Err(ServerFnError::ServerError {
305 message: err.message,
306 details: err.data,
307 code: err.code,
308 }))
309 }
310 Err(e) => Ok(Err(e)),
311 }
312 }
313 }
314 })
315 }
316 fn verify_can_deserialize(&self) -> Self::VerifyDecode {
317 DecodeIsVerified
318 }
319 }
320
321 impl<T, R, E> RequestDecodeResult<T, R> for &ServerFnDecoder<Result<T, E>> {
322 type VerifyDecode = CantDecode;
323
324 fn decode_client_response(
325 &self,
326 _res: Result<R, RequestError>,
327 ) -> impl Future<Output = Result<Result<T, ServerFnError>, RequestError>> + Send {
328 async move { unimplemented!() }
329 }
330
331 fn verify_can_deserialize(&self) -> Self::VerifyDecode {
332 CantDecode
333 }
334 }
335
336 pub trait RequestDecodeErr<T, E> {
337 fn decode_client_err(
338 &self,
339 res: Result<Result<T, ServerFnError>, RequestError>,
340 ) -> impl Future<Output = Result<T, E>> + Send;
341 }
342
343 impl<T, E> RequestDecodeErr<T, E> for &&&ServerFnDecoder<Result<T, E>>
344 where
345 E: From<ServerFnError> + DeserializeOwned + Serialize,
346 {
347 fn decode_client_err(
348 &self,
349 res: Result<Result<T, ServerFnError>, RequestError>,
350 ) -> impl Future<Output = Result<T, E>> + Send {
351 SendWrapper::new(async move {
352 match res {
353 Ok(Ok(res)) => Ok(res),
354 Ok(Err(e)) => match e {
355 ServerFnError::ServerError {
356 details,
357 message,
358 code,
359 } => {
360 match details {
363 Some(details) => match serde_json::from_value::<E>(details) {
364 Ok(res) => Err(res),
365 Err(err) => Err(E::from(ServerFnError::Deserialization(
366 err.to_string(),
367 ))),
368 },
369 None => Err(E::from(ServerFnError::ServerError {
370 message,
371 details: None,
372 code,
373 })),
374 }
375 }
376 err => Err(err.into()),
377 },
378 Err(err) => Err(ServerFnError::from(err).into()),
381 }
382 })
383 }
384 }
385
386 impl<T> RequestDecodeErr<T, anyhow::Error> for &&ServerFnDecoder<Result<T, anyhow::Error>> {
391 fn decode_client_err(
392 &self,
393 res: Result<Result<T, ServerFnError>, RequestError>,
394 ) -> impl Future<Output = Result<T, anyhow::Error>> + Send {
395 SendWrapper::new(async move {
396 match res {
397 Ok(Ok(res)) => Ok(res),
398 Ok(Err(e)) => Err(anyhow::Error::from(e)),
399 Err(err) => Err(anyhow::Error::from(err)),
400 }
401 })
402 }
403 }
404
405 impl<T> RequestDecodeErr<T, StatusCode> for &ServerFnDecoder<Result<T, StatusCode>> {
407 fn decode_client_err(
408 &self,
409 res: Result<Result<T, ServerFnError>, RequestError>,
410 ) -> impl Future<Output = Result<T, StatusCode>> + Send {
411 SendWrapper::new(async move {
412 match res {
413 Ok(Ok(res)) => Ok(res),
414
415 Ok(Err(e)) => match e {
417 ServerFnError::Request(error) => {
418 Err(StatusCode::from_u16(error.status_code().unwrap_or(500))
419 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
420 }
421
422 ServerFnError::ServerError {
423 message: _message,
424 details: _details,
425 code,
426 } => {
427 Err(StatusCode::from_u16(code)
428 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
429 }
430
431 ServerFnError::Registration(_) | ServerFnError::MiddlewareError(_) => {
432 Err(StatusCode::INTERNAL_SERVER_ERROR)
433 }
434
435 ServerFnError::Deserialization(_)
436 | ServerFnError::Serialization(_)
437 | ServerFnError::Args(_)
438 | ServerFnError::MissingArg(_)
439 | ServerFnError::StreamError(_) => Err(StatusCode::UNPROCESSABLE_ENTITY),
440
441 ServerFnError::UnsupportedRequestMethod(_) => {
442 Err(StatusCode::METHOD_NOT_ALLOWED)
443 }
444
445 ServerFnError::Response(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
446 },
447
448 Err(reqwest_err) => {
450 let code = reqwest_err
451 .status()
452 .unwrap_or(StatusCode::SERVICE_UNAVAILABLE);
453 Err(code)
454 }
455 }
456 })
457 }
458 }
459
460 impl<T> RequestDecodeErr<T, HttpError> for &ServerFnDecoder<Result<T, HttpError>> {
461 fn decode_client_err(
462 &self,
463 res: Result<Result<T, ServerFnError>, RequestError>,
464 ) -> impl Future<Output = Result<T, HttpError>> + Send {
465 SendWrapper::new(async move {
466 match res {
467 Ok(Ok(res)) => Ok(res),
468 Ok(Err(res)) => match res {
469 ServerFnError::ServerError { message, code, .. } => Err(HttpError {
470 status: StatusCode::from_u16(code)
471 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
472 message: Some(message),
473 }),
474 _ => HttpError::internal_server_error("Internal Server Error"),
475 },
476 Err(err) => Err(HttpError::new(
477 err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
478 err.to_string(),
479 )),
480 }
481 })
482 }
483 }
484}
485
486pub use req_from::*;
487pub mod req_from {
488 use super::*;
489 use axum::{extract::FromRequestParts, response::Response};
490 use dioxus_fullstack_core::FullstackContext;
491
492 pub trait ExtractRequest<In, Out, H, M = ()> {
493 fn extract_axum(
494 &self,
495 state: FullstackContext,
496 request: Request,
497 map: fn(In) -> Out,
498 ) -> impl Future<Output = Result<(Out, H), Response>> + 'static;
499 }
500
501 impl<In, M, H> ExtractRequest<In, (), H, M> for &&&&&&&&&&&ServerFnEncoder<In, ()>
504 where
505 H: FromRequest<FullstackContext, M> + 'static,
506 {
507 fn extract_axum(
508 &self,
509 state: FullstackContext,
510 request: Request,
511 _map: fn(In) -> (),
512 ) -> impl Future<Output = Result<((), H), Response>> + 'static {
513 async move {
514 H::from_request(request, &state)
515 .await
516 .map_err(|e| e.into_response())
517 .map(|out| ((), out))
518 }
519 }
520 }
521
522 impl<In, Out, H> ExtractRequest<In, Out, H> for &&&&&&&&&&ServerFnEncoder<In, Out>
524 where
525 In: DeserializeOwned + 'static,
526 Out: 'static,
527 H: FromRequestParts<FullstackContext>,
528 {
529 fn extract_axum(
530 &self,
531 _state: FullstackContext,
532 request: Request,
533 map: fn(In) -> Out,
534 ) -> impl Future<Output = Result<(Out, H), Response>> + 'static {
535 async move {
536 let (mut parts, body) = request.into_parts();
537 let headers = H::from_request_parts(&mut parts, &_state)
538 .await
539 .map_err(|e| e.into_response())?;
540
541 let request = Request::from_parts(parts, body);
542 let bytes = Bytes::from_request(request, &()).await.unwrap();
543 let as_str = String::from_utf8_lossy(&bytes);
544
545 let bytes = if as_str.is_empty() {
546 "{}".as_bytes()
547 } else {
548 &bytes
549 };
550
551 let out = serde_json::from_slice::<In>(bytes)
552 .map(map)
553 .map_err(|e| ServerFnError::from(e).into_response())?;
554
555 Ok((out, headers))
556 }
557 }
558 }
559
560 impl<In, Out, M, H> ExtractRequest<In, Out, H, M> for &&&&&&&&&ServerFnEncoder<In, Out>
562 where
563 Out: FromRequest<FullstackContext, M> + 'static,
564 H: FromRequestParts<FullstackContext>,
565 {
566 fn extract_axum(
567 &self,
568 state: FullstackContext,
569 request: Request,
570 _map: fn(In) -> Out,
571 ) -> impl Future<Output = Result<(Out, H), Response>> + 'static {
572 async move {
573 let (mut parts, body) = request.into_parts();
574 let headers = H::from_request_parts(&mut parts, &state)
575 .await
576 .map_err(|e| e.into_response())?;
577
578 let request = Request::from_parts(parts, body);
579
580 let res = Out::from_request(request, &state)
581 .await
582 .map_err(|e| e.into_response());
583
584 res.map(|out| (out, headers))
585 }
586 }
587 }
588}
589
590pub use resp::*;
591mod resp {
592 use crate::HttpError;
593
594 use super::*;
595 use axum::response::Response;
596 use dioxus_core::CapturedError;
597 use http::HeaderValue;
598
599 pub trait MakeAxumResponse<T, E, R> {
607 fn make_axum_response(self, result: Result<T, E>) -> Result<Response, E>;
608 }
609
610 impl<T, E, R> MakeAxumResponse<T, E, R> for &&&&ServerFnDecoder<Result<T, E>>
613 where
614 T: FromResponse<R> + IntoResponse,
615 {
616 fn make_axum_response(self, result: Result<T, E>) -> Result<Response, E> {
617 result.map(|v| v.into_response())
618 }
619 }
620
621 impl<T, E> MakeAxumResponse<T, E, ()> for &&&ServerFnDecoder<Result<T, E>>
624 where
625 T: DeserializeOwned + Serialize,
626 {
627 fn make_axum_response(self, result: Result<T, E>) -> Result<Response, E> {
628 match result {
629 Ok(res) => {
630 let body = serde_json::to_string(&res).unwrap();
631 let mut resp = Response::new(body.into());
632 resp.headers_mut().insert(
633 http::header::CONTENT_TYPE,
634 HeaderValue::from_static("application/json"),
635 );
636 *resp.status_mut() = StatusCode::OK;
637 Ok(resp)
638 }
639 Err(err) => Err(err),
640 }
641 }
642 }
643
644 #[allow(clippy::result_large_err)]
645 pub trait MakeAxumError<E> {
646 fn make_axum_error(self, result: Result<Response, E>) -> Response;
647 }
648
649 pub trait AsStatusCode {
651 fn as_status_code(&self) -> StatusCode;
652 }
653
654 impl AsStatusCode for ServerFnError {
655 fn as_status_code(&self) -> StatusCode {
656 match self {
657 Self::ServerError { code, .. } => {
658 StatusCode::from_u16(*code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
659 }
660 _ => StatusCode::INTERNAL_SERVER_ERROR,
661 }
662 }
663 }
664
665 impl<T, E> MakeAxumError<E> for &&&ServerFnDecoder<Result<T, E>>
666 where
667 E: AsStatusCode + From<ServerFnError> + Serialize + DeserializeOwned + Display,
668 {
669 fn make_axum_error(self, result: Result<Response, E>) -> Response {
670 match result {
671 Ok(res) => res,
672 Err(err) => {
673 let status_code = err.as_status_code();
674 let err = ErrorPayload {
675 code: status_code.as_u16(),
676 message: err.to_string(),
677 data: Some(err),
678 };
679 let body = serde_json::to_string(&err).unwrap();
680 let mut resp = Response::new(body.into());
681 resp.headers_mut().insert(
682 http::header::CONTENT_TYPE,
683 HeaderValue::from_static("application/json"),
684 );
685 *resp.status_mut() = status_code;
686 resp
687 }
688 }
689 }
690 }
691
692 impl<T> MakeAxumError<CapturedError> for &&ServerFnDecoder<Result<T, CapturedError>> {
693 fn make_axum_error(self, result: Result<Response, CapturedError>) -> Response {
694 match result {
695 Ok(res) => res,
696
697 Err(errr) if errr._strong_count() == 1 => {
699 let err = errr.into_inner().unwrap();
700 <&&ServerFnDecoder<Result<T, anyhow::Error>> as MakeAxumError<anyhow::Error>>::make_axum_error(
701 &&ServerFnDecoder::new(),
702 Err(err),
703 )
704 }
705
706 Err(errr) => {
707 let payload = match errr.downcast_ref::<ServerFnError>() {
710 Some(ServerFnError::ServerError {
711 message,
712 code,
713 details,
714 }) => ErrorPayload {
715 message: message.clone(),
716 code: *code,
717 data: details.clone(),
718 },
719 Some(other) => ErrorPayload {
720 message: other.to_string(),
721 code: 500,
722 data: None,
723 },
724 None => match errr.downcast_ref::<HttpError>() {
725 Some(http_err) => ErrorPayload {
726 message: http_err
727 .message
728 .clone()
729 .unwrap_or_else(|| http_err.status.to_string()),
730 code: http_err.status.as_u16(),
731 data: None,
732 },
733 None => ErrorPayload {
734 code: 500,
735 message: errr.to_string(),
736 data: None,
737 },
738 },
739 };
740
741 let body = serde_json::to_string(&payload).unwrap();
742 let mut resp = Response::new(body.into());
743 resp.headers_mut().insert(
744 http::header::CONTENT_TYPE,
745 HeaderValue::from_static("application/json"),
746 );
747 *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
748 resp
749 }
750 }
751 }
752 }
753
754 impl<T> MakeAxumError<anyhow::Error> for &&ServerFnDecoder<Result<T, anyhow::Error>> {
755 fn make_axum_error(self, result: Result<Response, anyhow::Error>) -> Response {
756 match result {
757 Ok(res) => res,
758 Err(errr) => {
759 let payload = match errr.downcast::<ServerFnError>() {
762 Ok(ServerFnError::ServerError {
763 message,
764 code,
765 details,
766 }) => ErrorPayload {
767 message,
768 code,
769 data: details,
770 },
771 Ok(other) => ErrorPayload {
772 message: other.to_string(),
773 code: 500,
774 data: None,
775 },
776 Err(err) => match err.downcast::<HttpError>() {
777 Ok(http_err) => ErrorPayload {
778 message: http_err
779 .message
780 .unwrap_or_else(|| http_err.status.to_string()),
781 code: http_err.status.as_u16(),
782 data: None,
783 },
784 Err(err) => ErrorPayload {
785 code: 500,
786 message: err.to_string(),
787 data: None,
788 },
789 },
790 };
791
792 let body = serde_json::to_string(&payload).unwrap();
793 let mut resp = Response::new(body.into());
794 resp.headers_mut().insert(
795 http::header::CONTENT_TYPE,
796 HeaderValue::from_static("application/json"),
797 );
798 *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
799 resp
800 }
801 }
802 }
803 }
804
805 impl<T> MakeAxumError<StatusCode> for &&ServerFnDecoder<Result<T, StatusCode>> {
806 fn make_axum_error(self, result: Result<Response, StatusCode>) -> Response {
807 match result {
808 Ok(resp) => resp,
809 Err(status) => {
810 let body = serde_json::to_string(&ErrorPayload::<()> {
811 code: status.as_u16(),
812 message: status.to_string(),
813 data: None,
814 })
815 .unwrap();
816 let mut resp = Response::new(body.into());
817 resp.headers_mut().insert(
818 http::header::CONTENT_TYPE,
819 HeaderValue::from_static("application/json"),
820 );
821 *resp.status_mut() = status;
822 resp
823 }
824 }
825 }
826 }
827
828 impl<T> MakeAxumError<HttpError> for &ServerFnDecoder<Result<T, HttpError>> {
829 fn make_axum_error(self, result: Result<Response, HttpError>) -> Response {
830 match result {
831 Ok(resp) => resp,
832 Err(http_err) => {
833 let body = serde_json::to_string(&ErrorPayload::<()> {
834 code: http_err.status.as_u16(),
835 message: http_err
836 .message
837 .unwrap_or_else(|| http_err.status.to_string()),
838 data: None,
839 })
840 .unwrap();
841 let mut resp = Response::new(body.into());
842 resp.headers_mut().insert(
843 http::header::CONTENT_TYPE,
844 HeaderValue::from_static("application/json"),
845 );
846 *resp.status_mut() = http_err.status;
847 resp
848 }
849 }
850 }
851 }
852}