dioxus_fullstack/
client.rs

1#![allow(unreachable_code)]
2
3use crate::{reqwest_error_to_request_error, StreamingError};
4use bytes::Bytes;
5use dioxus_fullstack_core::RequestError;
6use futures::Stream;
7use futures::{TryFutureExt, TryStreamExt};
8use headers::{ContentType, Header};
9use http::{response::Parts, Extensions, HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
10use send_wrapper::SendWrapper;
11use serde::{de::DeserializeOwned, Serialize};
12use std::sync::{LazyLock, Mutex, OnceLock};
13use std::{fmt::Display, pin::Pin, prelude::rust_2024::Future};
14use url::Url;
15
/// Process-wide `reqwest::Client` shared by every request built through
/// [`ClientRequest::new_reqwest_request`].
///
/// If nothing has been stored here yet, the first call to `new_reqwest_request` fills it with the
/// client returned by [`ClientRequest::new_reqwest_client`].
pub static GLOBAL_REQUEST_CLIENT: OnceLock<reqwest::Client> = OnceLock::new();

/// Result of sending a [`ClientRequest`]: the [`ClientResponse`] on success, a [`RequestError`]
/// otherwise.
pub type ClientResult = Result<ClientResponse, RequestError>;
19
/// A request to a server function, described independently of the HTTP backend that sends it.
///
/// The same value can be turned into a `reqwest` request builder or, with the `web` feature, a
/// `gloo_net` request builder.
pub struct ClientRequest {
    /// Target URL. When building a `gloo_net` request only its path and query are used.
    pub url: Url,
    /// Headers sent with the request, in addition to the `X-Request-Client` header added when
    /// the backend request is built.
    pub headers: HeaderMap,
    /// HTTP method of the request.
    pub method: Method,
    /// Extensions that are handed over to the resulting `ClientResponse` once the request is sent.
    pub extensions: Extensions,
}
26
27impl ClientRequest {
28    /// Create a new ClientRequest with the given method, url path, and query parameters.
29    pub fn new(method: http::Method, path: String, params: &impl Serialize) -> Self {
30        Self::fetch_inner(method, path, serde_qs::to_string(params).unwrap())
31    }
32
33    // Shrink monomorphization bloat by moving this to its own function
34    fn fetch_inner(method: http::Method, path: String, query: String) -> ClientRequest {
35        // On wasm, this doesn't matter since we always use relative URLs when making requests anyways
36        let mut server_url = get_server_url();
37
38        if server_url.is_empty() {
39            server_url = "http://this.is.not.a.real.url:9000";
40        }
41
42        let url = format!(
43            "{server_url}{path}{params}",
44            params = if query.is_empty() {
45                "".to_string()
46            } else {
47                format!("?{}", query)
48            }
49        )
50        .parse()
51        .unwrap();
52
53        let headers = get_request_headers();
54
55        ClientRequest {
56            method,
57            url,
58            headers,
59            extensions: Extensions::new(),
60        }
61    }
62
63    /// Get the HTTP method of this Request.
64    pub fn method(&self) -> &Method {
65        &self.method
66    }
67
68    pub fn url(&self) -> &Url {
69        &self.url
70    }
71
72    /// Extend the query parameters of this request with the given serialzable struct.
73    ///
74    /// This will use `serde_qs` to serialize the struct into query parameters. `serde_qs` has various
75    /// restrictions - make sure to read its documentation!
76    pub fn extend_query(mut self, query: &impl Serialize) -> Self {
77        let old_query = self.url.query().unwrap_or("");
78        let new_query = serde_qs::to_string(query).unwrap();
79        let combined_query = format!(
80            "{}{}{}",
81            old_query,
82            if old_query.is_empty() { "" } else { "&" },
83            new_query
84        );
85        self.url.set_query(Some(&combined_query));
86        self
87    }
88
89    /// Add a `Header` to this Request.
90    pub fn header(
91        mut self,
92        name: impl TryInto<HeaderName, Error = impl Display>,
93        value: impl TryInto<HeaderValue, Error = impl Display>,
94    ) -> Result<Self, RequestError> {
95        self.headers.append(
96            name.try_into()
97                .map_err(|d| RequestError::Builder(d.to_string()))?,
98            value
99                .try_into()
100                .map_err(|d| RequestError::Builder(d.to_string()))?,
101        );
102        Ok(self)
103    }
104
105    /// Add a `Header` to this Request.
106    pub fn typed_header<H: Header>(mut self, header: H) -> Self {
107        let mut headers = vec![];
108        header.encode(&mut headers);
109        for header in headers {
110            self.headers.append(H::name(), header);
111        }
112        self
113    }
114
115    /// Creates a new reqwest client with cookies set
116    pub fn new_reqwest_client() -> reqwest::Client {
117        #[allow(unused_mut)]
118        let mut client = reqwest::Client::builder();
119
120        #[cfg(not(target_arch = "wasm32"))]
121        {
122            use std::sync::Arc;
123            use std::sync::LazyLock;
124
125            static COOKIES: LazyLock<Arc<reqwest::cookie::Jar>> =
126                LazyLock::new(|| Arc::new(reqwest::cookie::Jar::default()));
127
128            client = client.cookie_store(true).cookie_provider(COOKIES.clone());
129        }
130
131        client.build().unwrap()
132    }
133
134    /// Creates a new reqwest request builder with the method, url, and headers set from this ClientRequest
135    ///
136    /// Using this method attaches `X-Request-Client: dioxus` header to the request.
137    pub fn new_reqwest_request(&self) -> reqwest::RequestBuilder {
138        let client = GLOBAL_REQUEST_CLIENT.get_or_init(Self::new_reqwest_client);
139
140        let mut req = client
141            .request(self.method.clone(), self.url.clone())
142            .header("X-Request-Client", "dioxus");
143
144        for (key, value) in self.headers.iter() {
145            req = req.header(key, value);
146        }
147
148        req
149    }
150
151    /// Using this method attaches `X-Request-Client-Dioxus` header to the request.
152    #[cfg(feature = "web")]
153    pub fn new_gloo_request(&self) -> gloo_net::http::RequestBuilder {
154        let mut builder = gloo_net::http::RequestBuilder::new(
155            format!(
156                "{path}{query_string}",
157                path = self.url.path(),
158                query_string = self
159                    .url
160                    .query()
161                    .map(|query| format!("?{query}"))
162                    .unwrap_or_default()
163            )
164            .as_str(),
165        )
166        .header("X-Request-Client", "dioxus")
167        .method(self.method.clone());
168
169        for (key, value) in self.headers.iter() {
170            let value = match value.to_str() {
171                Ok(v) => v,
172                Err(er) => {
173                    tracing::error!("Error converting header {key} value: {}", er);
174                    continue;
175                }
176            };
177
178            builder = builder.header(key.as_str(), value);
179        }
180
181        builder
182    }
183
184    /// Sends the request with multipart/form-data body constructed from the given FormData.
185    #[cfg(not(target_arch = "wasm32"))]
186    pub async fn send_multipart(
187        self,
188        form: &dioxus_html::FormData,
189    ) -> Result<ClientResponse, RequestError> {
190        let mut outgoing = reqwest::multipart::Form::new();
191
192        for (key, value) in form.values() {
193            match value {
194                dioxus_html::FormValue::Text(text) => {
195                    outgoing = outgoing.text(key.to_string(), text.to_string());
196                }
197                dioxus_html::FormValue::File(Some(file_data)) => {
198                    outgoing = outgoing
199                        .file(key.to_string(), file_data.path())
200                        .await
201                        .map_err(|e| {
202                            RequestError::Builder(format!(
203                                "Failed to add file to multipart form: {e}",
204                            ))
205                        })?;
206                }
207                dioxus_html::FormValue::File(None) => {
208                    // No file was selected for this input, so we skip it.
209                    outgoing = outgoing.part(key.to_string(), reqwest::multipart::Part::bytes(b""));
210                }
211            }
212        }
213
214        let res = self
215            .new_reqwest_request()
216            .multipart(outgoing)
217            .send()
218            .await
219            .map_err(reqwest_error_to_request_error)?;
220
221        Ok(ClientResponse {
222            response: Box::new(res),
223            extensions: self.extensions,
224        })
225    }
226
227    pub async fn send_form(self, data: &impl Serialize) -> Result<ClientResponse, RequestError> {
228        // For GET and HEAD requests, we encode the form data as query parameters.
229        // For other request methods, we encode the form data as the request body.
230        if matches!(*self.method(), Method::GET | Method::HEAD) {
231            return self.extend_query(data).send_empty_body().await;
232        }
233
234        let body =
235            serde_urlencoded::to_string(data).map_err(|err| RequestError::Body(err.to_string()))?;
236
237        self.typed_header(ContentType::form_url_encoded())
238            .send_raw_bytes(body)
239            .await
240    }
241
242    /// Sends the request with an empty body.
243    pub async fn send_empty_body(self) -> Result<ClientResponse, RequestError> {
244        #[cfg(feature = "web")]
245        if cfg!(target_arch = "wasm32") {
246            return self.send_js_value(wasm_bindgen::JsValue::UNDEFINED).await;
247        }
248
249        #[cfg(not(target_arch = "wasm32"))]
250        {
251            let res = self
252                .new_reqwest_request()
253                .send()
254                .await
255                .map_err(reqwest_error_to_request_error)?;
256
257            return Ok(ClientResponse {
258                response: Box::new(res),
259                extensions: self.extensions,
260            });
261        }
262
263        unimplemented!()
264    }
265
266    pub async fn send_raw_bytes(
267        self,
268        bytes: impl Into<Bytes>,
269    ) -> Result<ClientResponse, RequestError> {
270        #[cfg(feature = "web")]
271        if cfg!(target_arch = "wasm32") {
272            let bytes = bytes.into();
273            let uint_8_array = js_sys::Uint8Array::from(&bytes[..]);
274            return self.send_js_value(uint_8_array.into()).await;
275        }
276
277        #[cfg(not(target_arch = "wasm32"))]
278        {
279            let res = self
280                .new_reqwest_request()
281                .body(bytes.into())
282                .send()
283                .await
284                .map_err(reqwest_error_to_request_error)?;
285
286            return Ok(ClientResponse {
287                response: Box::new(res),
288                extensions: self.extensions,
289            });
290        }
291
292        unimplemented!()
293    }
294
295    /// Sends text data with the `text/plain; charset=utf-8` content type.
296    pub async fn send_text(
297        self,
298        text: impl Into<String> + Into<Bytes>,
299    ) -> Result<ClientResponse, RequestError> {
300        self.typed_header(ContentType::text_utf8())
301            .send_raw_bytes(text)
302            .await
303    }
304
305    /// Sends JSON data with the `application/json` content type.
306    pub async fn send_json(self, json: &impl Serialize) -> Result<ClientResponse, RequestError> {
307        let bytes =
308            serde_json::to_vec(json).map_err(|e| RequestError::Serialization(e.to_string()))?;
309
310        if bytes.is_empty() || bytes == b"{}" || bytes == b"null" {
311            return self.send_empty_body().await;
312        }
313
314        self.typed_header(ContentType::json())
315            .send_raw_bytes(bytes)
316            .await
317    }
318
319    pub async fn send_body_stream(
320        self,
321        stream: impl Stream<Item = Result<Bytes, StreamingError>> + Send + 'static,
322    ) -> Result<ClientResponse, RequestError> {
323        #[cfg(not(target_arch = "wasm32"))]
324        {
325            let res = self
326                .new_reqwest_request()
327                .body(reqwest::Body::wrap_stream(stream))
328                .send()
329                .await
330                .map_err(reqwest_error_to_request_error)?;
331
332            return Ok(ClientResponse {
333                response: Box::new(res),
334                extensions: self.extensions,
335            });
336        }
337
338        // On the web, we have to buffer the entire stream into a Blob before sending it,
339        // since the Fetch API doesn't support streaming request bodies on browsers yet.
340        #[cfg(feature = "web")]
341        {
342            use wasm_bindgen::JsValue;
343
344            let stream: Vec<Bytes> = stream.try_collect().await.map_err(|e| {
345                RequestError::Request(format!("Error collecting stream for request body: {}", e))
346            })?;
347
348            let uint_8_array =
349                js_sys::Uint8Array::new_with_length(stream.iter().map(|b| b.len() as u32).sum());
350
351            let mut offset = 0;
352            for chunk in stream {
353                uint_8_array.set(&js_sys::Uint8Array::from(&chunk[..]), offset);
354                offset += chunk.len() as u32;
355            }
356
357            return self.send_js_value(JsValue::from(uint_8_array)).await;
358        }
359
360        unimplemented!()
361    }
362
363    #[cfg(feature = "web")]
364    pub async fn send_js_value(
365        self,
366        value: wasm_bindgen::JsValue,
367    ) -> Result<ClientResponse, RequestError> {
368        use std::str::FromStr;
369
370        let inner = self
371            .new_gloo_request()
372            .body(value)
373            .map_err(|e| RequestError::Request(e.to_string()))?
374            .send()
375            .await
376            .map_err(|e| RequestError::Request(e.to_string()))?;
377
378        let status = inner.status();
379        let url = inner
380            .url()
381            .parse()
382            .map_err(|e| RequestError::Request(format!("Error parsing response URL: {}", e)))?;
383
384        let headers = {
385            let mut map = HeaderMap::new();
386            for (key, value) in inner.headers().entries() {
387                if let Ok(header_value) = http::HeaderValue::from_str(&value) {
388                    let header = HeaderName::from_str(&key).unwrap();
389                    map.append(header, header_value);
390                }
391            }
392            map
393        };
394
395        let content_length = headers
396            .get(http::header::CONTENT_LENGTH)
397            .and_then(|val| val.to_str().ok())
398            .and_then(|s| s.parse::<u64>().ok());
399
400        let status = http::StatusCode::from_u16(status).unwrap_or(http::StatusCode::OK);
401
402        Ok(ClientResponse {
403            extensions: self.extensions,
404            response: Box::new(browser::WrappedGlooResponse {
405                inner,
406                headers,
407                status,
408                url,
409                content_length,
410            }),
411        })
412    }
413}
414
// On wasm reqwest not being send/sync gets annoying, but it's not relevant since wasm is single-threaded
//
// SAFETY: NOTE(review): these impls are unconditional, so they also apply on native targets.
// Presumably sound because every field of `ClientRequest` (`Url`, `HeaderMap`, `Method`,
// `Extensions`) is itself `Send + Sync` — confirm, in particular for values stored in
// `extensions`.
unsafe impl Send for ClientRequest {}
unsafe impl Sync for ClientRequest {}
418
/// A wrapper type over the platform's HTTP response type.
///
/// This abstracts over the inner response type behind the [`ClientResponseDriver`] trait and
/// provides a way to store state associated with the response.
///
/// On native targets the driver is a `reqwest::Response`. On the web, requests sent through
/// `ClientRequest::send_js_value` produce a driver that wraps a `gloo_net::http::Response`
/// instead.
pub struct ClientResponse {
    /// Platform-specific response that answers all status, header, and body queries.
    pub(crate) response: Box<dyn ClientResponseDriver>,
    /// Extensions carried over from the `ClientRequest` that produced this response.
    pub(crate) extensions: Extensions,
}
430
431impl ClientResponse {
432    pub fn status(&self) -> StatusCode {
433        self.response.status()
434    }
435
436    pub fn headers(&self) -> &HeaderMap {
437        self.response.headers()
438    }
439
440    pub fn url(&self) -> &Url {
441        self.response.url()
442    }
443
444    pub fn content_length(&self) -> Option<u64> {
445        self.response.content_length()
446    }
447
448    pub async fn bytes(self) -> Result<Bytes, RequestError> {
449        self.response.bytes().await
450    }
451
452    pub fn bytes_stream(
453        self,
454    ) -> impl futures_util::Stream<Item = Result<Bytes, StreamingError>> + 'static + Unpin + Send
455    {
456        self.response.bytes_stream()
457    }
458
459    pub fn extensions(&self) -> &Extensions {
460        &self.extensions
461    }
462
463    pub fn extensions_mut(&mut self) -> &mut Extensions {
464        &mut self.extensions
465    }
466
467    pub async fn json<T: DeserializeOwned>(self) -> Result<T, RequestError> {
468        serde_json::from_slice(&self.bytes().await?)
469            .map_err(|e| RequestError::Decode(e.to_string()))
470    }
471
472    pub async fn text(self) -> Result<String, RequestError> {
473        self.response.text().await
474    }
475
476    /// Creates the `http::response::Parts` from this response.
477    pub fn make_parts(&self) -> Parts {
478        let mut response = http::response::Response::builder().status(self.response.status());
479
480        response = response.version(self.response.version());
481
482        for (key, value) in self.response.headers().iter() {
483            response = response.header(key, value);
484        }
485
486        let (parts, _) = response.body(()).unwrap().into_parts();
487
488        parts
489    }
490
491    /// Consumes the response, returning the head and a stream of the body.
492    pub fn into_parts(self) -> (Parts, impl Stream<Item = Result<Bytes, StreamingError>>) {
493        (self.make_parts(), self.bytes_stream())
494    }
495}
496
497/// Set the root server URL that all server function paths are relative to for the client.
498///
499/// If this is not set, it defaults to the origin.
500pub fn set_server_url(url: &'static str) {
501    ROOT_URL.set(url).unwrap();
502}
503
504/// Returns the root server URL for all server functions.
505pub fn get_server_url() -> &'static str {
506    ROOT_URL.get().copied().unwrap_or("")
507}
508
/// Root server URL stored by `set_server_url`; while unset, `get_server_url` returns `""`.
static ROOT_URL: OnceLock<&'static str> = OnceLock::new();
510
511/// Delete the extra request headers for all servers functions.
512pub fn clear_request_headers() {
513    REQUEST_HEADERS.lock().unwrap().clear();
514}
515
516/// Set the extra request headers for all servers functions.
517pub fn set_request_headers(headers: HeaderMap) {
518    *REQUEST_HEADERS.lock().unwrap() = headers;
519}
520
521/// Returns the extra request headers for all servers functions.
522pub fn get_request_headers() -> HeaderMap {
523    REQUEST_HEADERS.lock().unwrap().clone()
524}
525
/// Extra headers shared by all server function requests; starts out empty and is read and
/// replaced through `get_request_headers`, `set_request_headers`, and `clear_request_headers`.
static REQUEST_HEADERS: LazyLock<Mutex<HeaderMap>> = LazyLock::new(|| Mutex::new(HeaderMap::new()));
527
/// Backend-specific HTTP response that a `ClientResponse` delegates to.
///
/// The body-consuming methods take `self: Box<Self>` so they can be called on a boxed trait
/// object and use up the response.
pub trait ClientResponseDriver {
    /// Status code of the response.
    fn status(&self) -> StatusCode;
    /// Headers of the response.
    fn headers(&self) -> &HeaderMap;
    /// URL of the response.
    fn url(&self) -> &Url;
    /// HTTP version of the response; implementations that do not override this report HTTP/2.
    fn version(&self) -> http::Version {
        http::Version::HTTP_2
    }
    /// Content length of the response, if known.
    fn content_length(&self) -> Option<u64>;
    /// Consumes the response and resolves to the whole body as bytes.
    fn bytes(self: Box<Self>) -> Pin<Box<dyn Future<Output = Result<Bytes, RequestError>> + Send>>;
    /// Consumes the response and returns the body as a stream of byte chunks.
    fn bytes_stream(
        self: Box<Self>,
    ) -> Pin<Box<dyn Stream<Item = Result<Bytes, StreamingError>> + 'static + Unpin + Send>>;

    /// Consumes the response and resolves to the whole body as text.
    fn text(self: Box<Self>) -> Pin<Box<dyn Future<Output = Result<String, RequestError>> + Send>>;
}
543
544mod native {
545    use futures::Stream;
546
547    use super::*;
548
549    impl ClientResponseDriver for reqwest::Response {
550        fn status(&self) -> http::StatusCode {
551            reqwest::Response::status(self)
552        }
553
554        fn version(&self) -> http::Version {
555            #[cfg(target_arch = "wasm32")]
556            {
557                return http::Version::HTTP_2;
558            }
559
560            reqwest::Response::version(self)
561        }
562
563        fn headers(&self) -> &http::HeaderMap {
564            reqwest::Response::headers(self)
565        }
566
567        fn url(&self) -> &url::Url {
568            reqwest::Response::url(self)
569        }
570
571        fn content_length(&self) -> Option<u64> {
572            reqwest::Response::content_length(self)
573        }
574
575        fn bytes(
576            self: Box<Self>,
577        ) -> Pin<Box<dyn Future<Output = Result<Bytes, RequestError>> + Send>> {
578            Box::pin(SendWrapper::new(async move {
579                reqwest::Response::bytes(*self)
580                    .map_err(reqwest_error_to_request_error)
581                    .await
582            }))
583        }
584
585        fn bytes_stream(
586            self: Box<Self>,
587        ) -> Pin<Box<dyn Stream<Item = Result<Bytes, StreamingError>> + 'static + Unpin + Send>>
588        {
589            Box::pin(SendWrapper::new(
590                reqwest::Response::bytes_stream(*self).map_err(|_| StreamingError::Failed),
591            ))
592        }
593
594        fn text(
595            self: Box<Self>,
596        ) -> Pin<Box<dyn Future<Output = Result<String, RequestError>> + Send>> {
597            Box::pin(SendWrapper::new(async move {
598                reqwest::Response::text(*self)
599                    .map_err(reqwest_error_to_request_error)
600                    .await
601            }))
602        }
603    }
604}
605
#[cfg(feature = "web")]
mod browser {
    use crate::{ClientResponseDriver, StreamingError};
    use bytes::Bytes;
    use dioxus_fullstack_core::RequestError;
    use futures::{Stream, StreamExt};
    use http::{HeaderMap, StatusCode};
    use js_sys::Uint8Array;
    use send_wrapper::SendWrapper;
    use std::{pin::Pin, prelude::rust_2024::Future};
    use wasm_bindgen::JsCast;

    /// A `gloo_net` response together with the head values already converted to `http` types.
    pub(crate) struct WrappedGlooResponse {
        /// The underlying gloo response, used to read the body.
        pub(crate) inner: gloo_net::http::Response,
        /// Response headers converted into an `http::HeaderMap`.
        pub(crate) headers: HeaderMap,
        /// Response status converted into an `http::StatusCode`.
        pub(crate) status: StatusCode,
        /// Parsed response URL.
        pub(crate) url: url::Url,
        /// Content length of the response, if known.
        pub(crate) content_length: Option<u64>,
    }

    impl ClientResponseDriver for WrappedGlooResponse {
        fn status(&self) -> StatusCode {
            self.status
        }

        fn headers(&self) -> &HeaderMap {
            &self.headers
        }

        fn url(&self) -> &url::Url {
            &self.url
        }

        fn content_length(&self) -> Option<u64> {
            self.content_length
        }

        fn bytes(
            self: Box<Self>,
        ) -> Pin<Box<dyn Future<Output = Result<Bytes, RequestError>> + Send>> {
            let body = async move {
                match self.inner.binary().await {
                    Ok(bytes) => Ok(Bytes::from(bytes)),
                    Err(e) => Err(RequestError::Request(e.to_string())),
                }
            };
            Box::pin(SendWrapper::new(body))
        }

        fn bytes_stream(
            self: Box<Self>,
        ) -> Pin<Box<dyn Stream<Item = Result<Bytes, StreamingError>> + 'static + Unpin + Send>>
        {
            // A response without a body yields a single failed item.
            let Some(body) = self.inner.body() else {
                let failed = futures::stream::iter([Err::<Bytes, _>(StreamingError::Failed)]);
                return Box::pin(SendWrapper::new(failed));
            };

            // Each chunk must be a `Uint8Array`; anything else is reported as a failure.
            let chunks = wasm_streams::ReadableStream::from_raw(body)
                .into_stream()
                .map(|chunk| {
                    let value = chunk.map_err(|_| StreamingError::Failed)?;
                    let array = value
                        .dyn_into::<Uint8Array>()
                        .map_err(|_| StreamingError::Failed)?;
                    Ok::<Bytes, StreamingError>(Bytes::from(array.to_vec()))
                });

            Box::pin(SendWrapper::new(chunks))
        }

        fn text(
            self: Box<Self>,
        ) -> Pin<Box<dyn Future<Output = Result<String, RequestError>> + Send>> {
            let body = async move {
                match self.inner.text().await {
                    Ok(text) => Ok(text),
                    Err(e) => Err(RequestError::Request(e.to_string())),
                }
            };
            Box::pin(SendWrapper::new(body))
        }
    }
}