1use blitz_traits::net::{
6 Body, BoxedHandler, Bytes, NetCallback, NetProvider, Request, SharedCallback,
7};
8use data_url::DataUrl;
9use reqwest::Client;
10use std::sync::Arc;
11use tokio::{
12 runtime::Handle,
13 sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
14};
15
/// `User-Agent` header value sent with every network (non-`data:`, non-`file:`) request
/// made by [`Provider`].
// NOTE(review): the `rv:60.0` token does not match `Firefox/81.0`; confirm which Firefox
// version is meant to be advertised.
const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64; rv:60.0) Gecko/20100101 Firefox/81.0";
17
/// A [`NetProvider`] backed by `reqwest` whose fetches run as tasks on a tokio runtime.
///
/// `data:` and `file:` URLs are served in-process; every other scheme goes over HTTP(S).
pub struct Provider<D> {
    /// Runtime handle on which fetch tasks are spawned.
    rt: Handle,
    /// HTTP client; cloned into every request.
    client: Client,
    /// Callback handed to each [`BoxedHandler`] once its response bytes arrive.
    resource_callback: SharedCallback<D>,
}
23impl<D: 'static> Provider<D> {
24 pub fn new(resource_callback: SharedCallback<D>) -> Self {
25 #[cfg(feature = "cookies")]
26 let client = Client::builder().cookie_store(true).build().unwrap();
27 #[cfg(not(feature = "cookies"))]
28 let client = Client::new();
29
30 Self {
31 rt: Handle::current(),
32 client,
33 resource_callback,
34 }
35 }
36 pub fn shared(res_callback: SharedCallback<D>) -> Arc<dyn NetProvider<D>> {
37 Arc::new(Self::new(res_callback))
38 }
39 pub fn is_empty(&self) -> bool {
40 Arc::strong_count(&self.resource_callback) == 1
41 }
42 pub fn count(&self) -> usize {
43 Arc::strong_count(&self.resource_callback) - 1
44 }
45}
46impl<D: 'static> Provider<D> {
47 async fn fetch_inner(
48 client: Client,
49 request: Request,
50 ) -> Result<(String, Bytes), ProviderError> {
51 Ok(match request.url.scheme() {
52 "data" => {
53 let data_url = DataUrl::process(request.url.as_str())?;
54 let decoded = data_url.decode_to_vec()?;
55 (request.url.to_string(), Bytes::from(decoded.0))
56 }
57 "file" => {
58 let file_content = std::fs::read(request.url.path())?;
59 (request.url.to_string(), Bytes::from(file_content))
60 }
61 _ => {
62 let response = client
63 .request(request.method, request.url)
64 .headers(request.headers)
65 .header("Content-Type", request.content_type.as_str())
66 .header("User-Agent", USER_AGENT)
67 .apply_body(request.body, request.content_type.as_str())
68 .await
69 .send()
70 .await?;
71
72 (response.url().to_string(), response.bytes().await?)
73 }
74 })
75 }
76
77 async fn fetch_with_handler(
78 client: Client,
79 doc_id: usize,
80 request: Request,
81 handler: BoxedHandler<D>,
82 res_callback: SharedCallback<D>,
83 ) -> Result<(), ProviderError> {
84 let (_response_url, bytes) = Self::fetch_inner(client, request).await?;
85 handler.bytes(doc_id, bytes, res_callback);
86 Ok(())
87 }
88
89 #[allow(clippy::type_complexity)]
90 pub fn fetch_with_callback(
91 &self,
92 request: Request,
93 callback: Box<dyn FnOnce(Result<(String, Bytes), ProviderError>) + Send + Sync + 'static>,
94 ) {
95 #[cfg(feature = "debug_log")]
96 let url = request.url.to_string();
97
98 let client = self.client.clone();
99 self.rt.spawn(async move {
100 let result = Self::fetch_inner(client, request).await;
101
102 #[cfg(feature = "debug_log")]
103 if let Err(e) = &result {
104 eprintln!("Error fetching {url}: {e:?}");
105 } else {
106 println!("Success {url}");
107 }
108
109 callback(result);
110 });
111 }
112
113 pub async fn fetch_async(&self, request: Request) -> Result<(String, Bytes), ProviderError> {
114 #[cfg(feature = "debug_log")]
115 let url = request.url.to_string();
116
117 let client = self.client.clone();
118 let result = Self::fetch_inner(client, request).await;
119
120 #[cfg(feature = "debug_log")]
121 if let Err(e) = &result {
122 eprintln!("Error fetching {url}: {e:?}");
123 } else {
124 println!("Success {url}");
125 }
126
127 result
128 }
129}
130
131impl<D: 'static> NetProvider<D> for Provider<D> {
132 fn fetch(&self, doc_id: usize, request: Request, handler: BoxedHandler<D>) {
133 let client = self.client.clone();
134 let callback = Arc::clone(&self.resource_callback);
135
136 #[cfg(feature = "debug_log")]
137 println!("Fetching {}", &request.url);
138
139 self.rt.spawn(async move {
140 #[cfg(feature = "debug_log")]
141 let url = request.url.to_string();
142
143 let _res = Self::fetch_with_handler(client, doc_id, request, handler, callback).await;
144
145 #[cfg(feature = "debug_log")]
146 if let Err(e) = _res {
147 eprintln!("Error fetching {url}: {e:?}");
148 } else {
149 println!("Success {url}");
150 }
151 });
152 }
153}
154
155#[derive(Debug)]
156pub enum ProviderError {
157 Io(std::io::Error),
158 DataUrl(data_url::DataUrlError),
159 DataUrlBase64(data_url::forgiving_base64::InvalidBase64),
160 ReqwestError(reqwest::Error),
161}
162
163impl From<std::io::Error> for ProviderError {
164 fn from(value: std::io::Error) -> Self {
165 Self::Io(value)
166 }
167}
168
169impl From<data_url::DataUrlError> for ProviderError {
170 fn from(value: data_url::DataUrlError) -> Self {
171 Self::DataUrl(value)
172 }
173}
174
175impl From<data_url::forgiving_base64::InvalidBase64> for ProviderError {
176 fn from(value: data_url::forgiving_base64::InvalidBase64) -> Self {
177 Self::DataUrlBase64(value)
178 }
179}
180
181impl From<reqwest::Error> for ProviderError {
182 fn from(value: reqwest::Error) -> Self {
183 Self::ReqwestError(value)
184 }
185}
186
187pub struct MpscCallback<T>(UnboundedSender<(usize, T)>);
188impl<T> MpscCallback<T> {
189 pub fn new() -> (UnboundedReceiver<(usize, T)>, Self) {
190 let (send, recv) = unbounded_channel();
191 (recv, Self(send))
192 }
193}
194impl<T: Send + Sync + 'static> NetCallback<T> for MpscCallback<T> {
195 fn call(&self, doc_id: usize, result: Result<T, Option<String>>) {
196 if let Ok(data) = result {
198 let _ = self.0.send((doc_id, data));
199 }
200 }
201}
202
/// Extension for attaching a blitz [`Body`] to a `reqwest` request builder.
trait ReqwestExt {
    /// Returns the builder with `body` attached, using `content_type` to choose how form
    /// bodies are encoded.
    async fn apply_body(self, body: Body, content_type: &str) -> Self;
}
206impl ReqwestExt for reqwest::RequestBuilder {
207 async fn apply_body(self, body: Body, content_type: &str) -> Self {
208 match body {
209 Body::Bytes(bytes) => self.body(bytes),
210 Body::Form(form_data) => match content_type {
211 "application/x-www-form-urlencoded" => self.form(&form_data),
212 #[cfg(feature = "multipart")]
213 "multipart/form-data" => {
214 use blitz_traits::net::Entry;
215 use blitz_traits::net::EntryValue;
216 let mut form_data = form_data;
217 let mut form = reqwest::multipart::Form::new();
218 for Entry { name, value } in form_data.0.drain(..) {
219 form = match value {
220 EntryValue::String(value) => form.text(name, value),
221 EntryValue::File(path_buf) => form
222 .file(name, path_buf)
223 .await
224 .expect("Couldn't read form file from disk"),
225 EntryValue::EmptyFile => form.part(
226 name,
227 reqwest::multipart::Part::bytes(&[])
228 .mime_str("application/octet-stream")
229 .unwrap(),
230 ),
231 };
232 }
233 self.multipart(form)
234 }
235 _ => self,
236 },
237 Body::Empty => self,
238 }
239 }
240}