1use blitz_traits::net::{AbortSignal, Body, Bytes, NetHandler, NetProvider, NetWaker, Request};
7use data_url::DataUrl;
8use std::{marker::PhantomData, pin::Pin, sync::Arc, task::Poll};
9
10#[cfg(feature = "cache")]
11use http_cache_reqwest::{CACacheManager, Cache, CacheMode, HttpCache, HttpCacheOptions};
12
13const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64; rv:60.0) Gecko/20100101 Firefox/81.0";
14
15#[cfg(feature = "cache")]
16type Client = reqwest_middleware::ClientWithMiddleware;
17#[cfg(not(feature = "cache"))]
18type Client = reqwest::Client;
19
20#[cfg(feature = "cache")]
21type RequestBuilder = reqwest_middleware::RequestBuilder;
22#[cfg(not(feature = "cache"))]
23type RequestBuilder = reqwest::RequestBuilder;
24
/// Returns the directory handed to the on-disk HTTP cache manager.
///
/// # Panics
///
/// Panics with "Failed to find cache directory" when `ProjectDirs::from` returns `None`.
#[cfg(feature = "cache")]
fn get_cache_path() -> std::path::PathBuf {
    let project_dirs = directories::ProjectDirs::from("com", "DioxusLabs", "Blitz")
        .expect("Failed to find cache directory");
    let cache_dir = project_dirs.cache_dir().to_path_buf();

    #[cfg(feature = "tracing")]
    tracing::info!(path = ?cache_dir.display(), "Using cache dir");

    cache_dir
}
36
/// Runs `fut` to completion in the background on wasm32 targets, discarding its output.
#[cfg(target_arch = "wasm32")]
fn spawn(fut: impl Future + 'static) {
    // `spawn_local` takes a future with unit output, so the real output is awaited and dropped.
    let detached = async move {
        fut.await;
    };
    wasm_bindgen_futures::spawn_local(detached);
}
43
44#[cfg(not(target_arch = "wasm32"))]
45fn spawn<F>(fut: F)
46where
47 F: Future + Send + 'static,
48 F::Output: Send + 'static,
49{
50 tokio::spawn(fut);
51}
52
53pub struct Provider {
54 client: Client,
55 waker: Arc<dyn NetWaker>,
56}
57impl Provider {
58 pub fn new(waker: Option<Arc<dyn NetWaker>>) -> Self {
59 let builder = reqwest::Client::builder();
60 #[cfg(feature = "cookies")]
61 let builder = builder.cookie_store(true);
62 let client = builder.build().unwrap();
63
64 #[cfg(feature = "cache")]
65 let client = reqwest_middleware::ClientBuilder::new(client)
66 .with(Cache(HttpCache {
67 mode: CacheMode::Default,
68 manager: CACacheManager::new(get_cache_path(), true),
69 options: HttpCacheOptions::default(),
70 }))
71 .build();
72
73 let waker = waker.unwrap_or(Arc::new(DummyNetWaker));
74 Self { client, waker }
75 }
76 pub fn shared(waker: Option<Arc<dyn NetWaker>>) -> Arc<dyn NetProvider> {
77 Arc::new(Self::new(waker))
78 }
79 pub fn is_empty(&self) -> bool {
80 Arc::strong_count(&self.waker) == 1
81 }
82 pub fn count(&self) -> usize {
83 Arc::strong_count(&self.waker) - 1
84 }
85}
86impl Provider {
87 async fn fetch_inner(
88 client: Client,
89 request: Request,
90 ) -> Result<(String, Bytes), ProviderError> {
91 Ok(match request.url.scheme() {
92 "data" => {
93 let data_url = DataUrl::process(request.url.as_str())?;
94 let decoded = data_url.decode_to_vec()?;
95 (request.url.to_string(), Bytes::from(decoded.0))
96 }
97 "file" => {
98 let file_content = std::fs::read(request.url.path())?;
99 (request.url.to_string(), Bytes::from(file_content))
100 }
101 _ => {
102 let response = client
103 .request(request.method, request.url)
104 .headers(request.headers)
105 .header("Content-Type", request.content_type.as_str())
106 .header("User-Agent", USER_AGENT)
107 .apply_body(request.body, request.content_type.as_str())
108 .await
109 .send()
110 .await?;
111
112 (response.url().to_string(), response.bytes().await?)
113 }
114 })
115 }
116
117 #[allow(clippy::type_complexity)]
118 pub fn fetch_with_callback(
119 &self,
120 request: Request,
121 callback: Box<dyn FnOnce(Result<(String, Bytes), ProviderError>) + Send + Sync + 'static>,
122 ) {
123 #[cfg(feature = "tracing")]
124 let url = request.url.to_string();
125
126 let client = self.client.clone();
127 spawn(async move {
128 let result = Self::fetch_inner(client, request).await;
129
130 #[cfg(feature = "tracing")]
131 if let Err(e) = &result {
132 #[cfg(feature = "tracing")]
133 tracing::error!(url = url.as_str(), error = ?e, "Fetching");
134 } else {
135 #[cfg(feature = "tracing")]
136 tracing::info!(url = url.as_str(), "Success fetching");
137 }
138
139 callback(result);
140 });
141 }
142
143 pub async fn fetch_async(&self, request: Request) -> Result<(String, Bytes), ProviderError> {
144 #[cfg(feature = "tracing")]
145 let url = request.url.to_string();
146
147 let client = self.client.clone();
148 let result = Self::fetch_inner(client, request).await;
149
150 #[cfg(feature = "tracing")]
151 if let Err(e) = &result {
152 #[cfg(feature = "tracing")]
153 tracing::error!(url = url.as_str(), error = ?e, "Fetching");
154 } else {
155 #[cfg(feature = "tracing")]
156 tracing::info!(url = url.as_str(), "Success fetching");
157 }
158
159 result
160 }
161}
162
163impl NetProvider for Provider {
164 fn fetch(&self, doc_id: usize, mut request: Request, handler: Box<dyn NetHandler>) {
165 let client = self.client.clone();
166
167 #[cfg(feature = "tracing")]
168 tracing::info!(url = request.url.as_str(), "Fetching");
169
170 let waker = self.waker.clone();
171 spawn(async move {
172 #[cfg(feature = "tracing")]
173 let url = request.url.to_string();
174
175 let signal = request.signal.take();
176 let result = if let Some(signal) = signal {
177 AbortFetch::new(
178 signal,
179 Box::pin(async move { Self::fetch_inner(client, request).await }),
180 )
181 .await
182 } else {
183 Self::fetch_inner(client, request).await
184 };
185
186 waker.wake(doc_id);
188
189 match result {
190 Ok((response_url, bytes)) => {
191 handler.bytes(response_url, bytes);
192 #[cfg(feature = "tracing")]
193 tracing::info!(url = url.as_str(), "Success fetching");
194 }
195 Err(e) => {
196 #[cfg(feature = "tracing")]
197 tracing::error!(url = url.as_str(), error = ?e, "Error fetching");
198 #[cfg(not(feature = "tracing"))]
199 let _ = e;
200 }
201 };
202 });
203 }
204}
205
206struct AbortFetch<F, T> {
208 signal: AbortSignal,
209 future: F,
210 _rt: PhantomData<T>,
211}
212
213impl<F, T> AbortFetch<F, T> {
214 fn new(signal: AbortSignal, future: F) -> Self {
215 Self {
216 signal,
217 future,
218 _rt: PhantomData,
219 }
220 }
221}
222
223impl<F, T> Future for AbortFetch<F, T>
224where
225 F: Future + Unpin + 'static,
226 F::Output: Into<Result<T, ProviderError>> + 'static,
227 T: Unpin,
228{
229 type Output = Result<T, ProviderError>;
230
231 fn poll(
232 mut self: std::pin::Pin<&mut Self>,
233 cx: &mut std::task::Context<'_>,
234 ) -> std::task::Poll<Self::Output> {
235 if self.signal.aborted() {
236 return Poll::Ready(Err(ProviderError::Abort));
237 }
238
239 match Pin::new(&mut self.future).poll(cx) {
240 Poll::Ready(output) => Poll::Ready(output.into()),
241 Poll::Pending => Poll::Pending,
242 }
243 }
244}
245
246#[derive(Debug)]
247pub enum ProviderError {
248 Abort,
249 Io(std::io::Error),
250 DataUrl(data_url::DataUrlError),
251 DataUrlBase64(data_url::forgiving_base64::InvalidBase64),
252 ReqwestError(reqwest::Error),
253 #[cfg(feature = "cache")]
254 ReqwestMiddlewareError(reqwest_middleware::Error),
255}
256
257impl From<std::io::Error> for ProviderError {
258 fn from(value: std::io::Error) -> Self {
259 Self::Io(value)
260 }
261}
262
263impl From<data_url::DataUrlError> for ProviderError {
264 fn from(value: data_url::DataUrlError) -> Self {
265 Self::DataUrl(value)
266 }
267}
268
269impl From<data_url::forgiving_base64::InvalidBase64> for ProviderError {
270 fn from(value: data_url::forgiving_base64::InvalidBase64) -> Self {
271 Self::DataUrlBase64(value)
272 }
273}
274
275impl From<reqwest::Error> for ProviderError {
276 fn from(value: reqwest::Error) -> Self {
277 Self::ReqwestError(value)
278 }
279}
280
281#[cfg(feature = "cache")]
282impl From<reqwest_middleware::Error> for ProviderError {
283 fn from(value: reqwest_middleware::Error) -> Self {
284 Self::ReqwestMiddlewareError(value)
285 }
286}
287
288trait ReqwestExt {
289 async fn apply_body(self, body: Body, content_type: &str) -> Self;
290}
291impl ReqwestExt for RequestBuilder {
292 async fn apply_body(self, body: Body, content_type: &str) -> Self {
293 match body {
294 Body::Bytes(bytes) => self.body(bytes),
295 Body::Form(form_data) => match content_type {
296 "application/x-www-form-urlencoded" => self.form(&form_data),
297 #[cfg(feature = "multipart")]
298 "multipart/form-data" => {
299 use blitz_traits::net::Entry;
300 use blitz_traits::net::EntryValue;
301 let mut form_data = form_data;
302 let mut form = reqwest::multipart::Form::new();
303 for Entry { name, value } in form_data.0.drain(..) {
304 form = match value {
305 EntryValue::String(value) => form.text(name, value),
306 EntryValue::File(path_buf) => form
307 .file(name, path_buf)
308 .await
309 .expect("Couldn't read form file from disk"),
310 EntryValue::EmptyFile => form.part(
311 name,
312 reqwest::multipart::Part::bytes(&[])
313 .mime_str("application/octet-stream")
314 .unwrap(),
315 ),
316 };
317 }
318 self.multipart(form)
319 }
320 _ => self,
321 },
322 Body::Empty => self,
323 }
324 }
325}
326
327struct DummyNetWaker;
328impl NetWaker for DummyNetWaker {
329 fn wake(&self, _client_id: usize) {}
330}