Skip to main content

blitz_net/
lib.rs

1//! Networking (HTTP, filesystem, Data URIs) for Blitz
2//!
3//! Provides an implementation of the [`blitz_traits::net::NetProvider`] trait.
4
5// use blitz_traits::net::{Body, Bytes, NetHandler, NetProvider, NetWaker, Request};
6use blitz_traits::net::{AbortSignal, Body, Bytes, NetHandler, NetProvider, NetWaker, Request};
7use data_url::DataUrl;
8use std::{marker::PhantomData, pin::Pin, sync::Arc, task::Poll};
9
10#[cfg(feature = "cache")]
11use http_cache_reqwest::{CACacheManager, Cache, CacheMode, HttpCache, HttpCacheOptions};
12
/// `User-Agent` header value sent with every HTTP(S) request.
///
/// Mimics Firefox 81 on Linux. In genuine Firefox desktop user-agent strings the Gecko
/// `rv:` token carries the same version as the trailing `Firefox/` token, so both are 81.0
/// (the previous value paired `rv:60.0` with `Firefox/81.0`, which no real browser sends).
const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64; rv:81.0) Gecko/20100101 Firefox/81.0";
14
15#[cfg(feature = "cache")]
16type Client = reqwest_middleware::ClientWithMiddleware;
17#[cfg(not(feature = "cache"))]
18type Client = reqwest::Client;
19
20#[cfg(feature = "cache")]
21type RequestBuilder = reqwest_middleware::RequestBuilder;
22#[cfg(not(feature = "cache"))]
23type RequestBuilder = reqwest::RequestBuilder;
24
/// Returns the on-disk directory used by the HTTP cache middleware.
///
/// The directory is the per-user cache dir that `directories::ProjectDirs` reports for the
/// `com` / `DioxusLabs` / `Blitz` project.
///
/// # Panics
/// Panics with "Failed to find cache directory" when `ProjectDirs::from` returns `None`.
#[cfg(feature = "cache")]
fn get_cache_path() -> std::path::PathBuf {
    use directories::ProjectDirs;

    let project_dirs =
        ProjectDirs::from("com", "DioxusLabs", "Blitz").expect("Failed to find cache directory");
    let cache_dir = project_dirs.cache_dir().to_path_buf();

    #[cfg(feature = "tracing")]
    tracing::info!(path = ?cache_dir.display(), "Using cache dir");

    cache_dir
}
36
/// Runs `fut` on the current thread via `wasm_bindgen_futures::spawn_local` (wasm targets).
///
/// The future's output is discarded; nothing is returned to the caller.
#[cfg(target_arch = "wasm32")]
fn spawn(fut: impl Future + 'static) {
    let task = async move {
        let _ = fut.await;
    };
    wasm_bindgen_futures::spawn_local(task);
}
43
44#[cfg(not(target_arch = "wasm32"))]
45fn spawn<F>(fut: F)
46where
47    F: Future + Send + 'static,
48    F::Output: Send + 'static,
49{
50    tokio::spawn(fut);
51}
52
53pub struct Provider {
54    client: Client,
55    waker: Arc<dyn NetWaker>,
56}
57impl Provider {
58    pub fn new(waker: Option<Arc<dyn NetWaker>>) -> Self {
59        let builder = reqwest::Client::builder();
60        #[cfg(feature = "cookies")]
61        let builder = builder.cookie_store(true);
62        let client = builder.build().unwrap();
63
64        #[cfg(feature = "cache")]
65        let client = reqwest_middleware::ClientBuilder::new(client)
66            .with(Cache(HttpCache {
67                mode: CacheMode::Default,
68                manager: CACacheManager::new(get_cache_path(), true),
69                options: HttpCacheOptions::default(),
70            }))
71            .build();
72
73        let waker = waker.unwrap_or(Arc::new(DummyNetWaker));
74        Self { client, waker }
75    }
76    pub fn shared(waker: Option<Arc<dyn NetWaker>>) -> Arc<dyn NetProvider> {
77        Arc::new(Self::new(waker))
78    }
79    pub fn is_empty(&self) -> bool {
80        Arc::strong_count(&self.waker) == 1
81    }
82    pub fn count(&self) -> usize {
83        Arc::strong_count(&self.waker) - 1
84    }
85}
86impl Provider {
87    async fn fetch_inner(
88        client: Client,
89        request: Request,
90    ) -> Result<(String, Bytes), ProviderError> {
91        Ok(match request.url.scheme() {
92            "data" => {
93                let data_url = DataUrl::process(request.url.as_str())?;
94                let decoded = data_url.decode_to_vec()?;
95                (request.url.to_string(), Bytes::from(decoded.0))
96            }
97            "file" => {
98                let file_content = std::fs::read(request.url.path())?;
99                (request.url.to_string(), Bytes::from(file_content))
100            }
101            _ => {
102                let response = client
103                    .request(request.method, request.url)
104                    .headers(request.headers)
105                    .header("Content-Type", request.content_type.as_str())
106                    .header("User-Agent", USER_AGENT)
107                    .apply_body(request.body, request.content_type.as_str())
108                    .await
109                    .send()
110                    .await?;
111
112                (response.url().to_string(), response.bytes().await?)
113            }
114        })
115    }
116
117    #[allow(clippy::type_complexity)]
118    pub fn fetch_with_callback(
119        &self,
120        request: Request,
121        callback: Box<dyn FnOnce(Result<(String, Bytes), ProviderError>) + Send + Sync + 'static>,
122    ) {
123        #[cfg(feature = "tracing")]
124        let url = request.url.to_string();
125
126        let client = self.client.clone();
127        spawn(async move {
128            let result = Self::fetch_inner(client, request).await;
129
130            #[cfg(feature = "tracing")]
131            if let Err(e) = &result {
132                #[cfg(feature = "tracing")]
133                tracing::error!(url = url.as_str(), error = ?e, "Fetching");
134            } else {
135                #[cfg(feature = "tracing")]
136                tracing::info!(url = url.as_str(), "Success fetching");
137            }
138
139            callback(result);
140        });
141    }
142
143    pub async fn fetch_async(&self, request: Request) -> Result<(String, Bytes), ProviderError> {
144        #[cfg(feature = "tracing")]
145        let url = request.url.to_string();
146
147        let client = self.client.clone();
148        let result = Self::fetch_inner(client, request).await;
149
150        #[cfg(feature = "tracing")]
151        if let Err(e) = &result {
152            #[cfg(feature = "tracing")]
153            tracing::error!(url = url.as_str(), error = ?e, "Fetching");
154        } else {
155            #[cfg(feature = "tracing")]
156            tracing::info!(url = url.as_str(), "Success fetching");
157        }
158
159        result
160    }
161}
162
163impl NetProvider for Provider {
164    fn fetch(&self, doc_id: usize, mut request: Request, handler: Box<dyn NetHandler>) {
165        let client = self.client.clone();
166
167        #[cfg(feature = "tracing")]
168        tracing::info!(url = request.url.as_str(), "Fetching");
169
170        let waker = self.waker.clone();
171        spawn(async move {
172            #[cfg(feature = "tracing")]
173            let url = request.url.to_string();
174
175            let signal = request.signal.take();
176            let result = if let Some(signal) = signal {
177                AbortFetch::new(
178                    signal,
179                    Box::pin(async move { Self::fetch_inner(client, request).await }),
180                )
181                .await
182            } else {
183                Self::fetch_inner(client, request).await
184            };
185
186            // Call the waker to notify of completed network request
187            waker.wake(doc_id);
188
189            match result {
190                Ok((response_url, bytes)) => {
191                    handler.bytes(response_url, bytes);
192                    #[cfg(feature = "tracing")]
193                    tracing::info!(url = url.as_str(), "Success fetching");
194                }
195                Err(e) => {
196                    #[cfg(feature = "tracing")]
197                    tracing::error!(url = url.as_str(), error = ?e, "Error fetching");
198                    #[cfg(not(feature = "tracing"))]
199                    let _ = e;
200                }
201            };
202        });
203    }
204}
205
206/// A future that is cancellable using an AbortSignal
207struct AbortFetch<F, T> {
208    signal: AbortSignal,
209    future: F,
210    _rt: PhantomData<T>,
211}
212
213impl<F, T> AbortFetch<F, T> {
214    fn new(signal: AbortSignal, future: F) -> Self {
215        Self {
216            signal,
217            future,
218            _rt: PhantomData,
219        }
220    }
221}
222
223impl<F, T> Future for AbortFetch<F, T>
224where
225    F: Future + Unpin + 'static,
226    F::Output: Into<Result<T, ProviderError>> + 'static,
227    T: Unpin,
228{
229    type Output = Result<T, ProviderError>;
230
231    fn poll(
232        mut self: std::pin::Pin<&mut Self>,
233        cx: &mut std::task::Context<'_>,
234    ) -> std::task::Poll<Self::Output> {
235        if self.signal.aborted() {
236            return Poll::Ready(Err(ProviderError::Abort));
237        }
238
239        match Pin::new(&mut self.future).poll(cx) {
240            Poll::Ready(output) => Poll::Ready(output.into()),
241            Poll::Pending => Poll::Pending,
242        }
243    }
244}
245
246#[derive(Debug)]
247pub enum ProviderError {
248    Abort,
249    Io(std::io::Error),
250    DataUrl(data_url::DataUrlError),
251    DataUrlBase64(data_url::forgiving_base64::InvalidBase64),
252    ReqwestError(reqwest::Error),
253    #[cfg(feature = "cache")]
254    ReqwestMiddlewareError(reqwest_middleware::Error),
255}
256
257impl From<std::io::Error> for ProviderError {
258    fn from(value: std::io::Error) -> Self {
259        Self::Io(value)
260    }
261}
262
263impl From<data_url::DataUrlError> for ProviderError {
264    fn from(value: data_url::DataUrlError) -> Self {
265        Self::DataUrl(value)
266    }
267}
268
269impl From<data_url::forgiving_base64::InvalidBase64> for ProviderError {
270    fn from(value: data_url::forgiving_base64::InvalidBase64) -> Self {
271        Self::DataUrlBase64(value)
272    }
273}
274
275impl From<reqwest::Error> for ProviderError {
276    fn from(value: reqwest::Error) -> Self {
277        Self::ReqwestError(value)
278    }
279}
280
281#[cfg(feature = "cache")]
282impl From<reqwest_middleware::Error> for ProviderError {
283    fn from(value: reqwest_middleware::Error) -> Self {
284        Self::ReqwestMiddlewareError(value)
285    }
286}
287
288trait ReqwestExt {
289    async fn apply_body(self, body: Body, content_type: &str) -> Self;
290}
291impl ReqwestExt for RequestBuilder {
292    async fn apply_body(self, body: Body, content_type: &str) -> Self {
293        match body {
294            Body::Bytes(bytes) => self.body(bytes),
295            Body::Form(form_data) => match content_type {
296                "application/x-www-form-urlencoded" => self.form(&form_data),
297                #[cfg(feature = "multipart")]
298                "multipart/form-data" => {
299                    use blitz_traits::net::Entry;
300                    use blitz_traits::net::EntryValue;
301                    let mut form_data = form_data;
302                    let mut form = reqwest::multipart::Form::new();
303                    for Entry { name, value } in form_data.0.drain(..) {
304                        form = match value {
305                            EntryValue::String(value) => form.text(name, value),
306                            EntryValue::File(path_buf) => form
307                                .file(name, path_buf)
308                                .await
309                                .expect("Couldn't read form file from disk"),
310                            EntryValue::EmptyFile => form.part(
311                                name,
312                                reqwest::multipart::Part::bytes(&[])
313                                    .mime_str("application/octet-stream")
314                                    .unwrap(),
315                            ),
316                        };
317                    }
318                    self.multipart(form)
319                }
320                _ => self,
321            },
322            Body::Empty => self,
323        }
324    }
325}
326
327struct DummyNetWaker;
328impl NetWaker for DummyNetWaker {
329    fn wake(&self, _client_id: usize) {}
330}