1use blitz_traits::net::{
6 Body, BoxedHandler, Bytes, NetCallback, NetProvider, Request, SharedCallback,
7};
8use data_url::DataUrl;
9use reqwest::Client;
10use std::sync::Arc;
11use tokio::{
12 runtime::Handle,
13 sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
14};
15
/// `User-Agent` header value sent with every HTTP(S) request made by [`Provider`].
///
/// Formatted like a desktop Firefox 81 on Linux. In genuine Firefox user-agent strings the
/// `rv:` (Gecko) token equals the Firefox version, so both are kept at 81.0 here.
const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64; rv:81.0) Gecko/20100101 Firefox/81.0";
17
18pub struct Provider<D> {
19 rt: Handle,
20 client: Client,
21 resource_callback: SharedCallback<D>,
22}
23impl<D: 'static> Provider<D> {
24 pub fn new(resource_callback: SharedCallback<D>) -> Self {
25 #[cfg(feature = "cookies")]
26 let client = Client::builder().cookie_store(true).build().unwrap();
27 #[cfg(not(feature = "cookies"))]
28 let client = Client::new();
29
30 Self {
31 rt: Handle::current(),
32 client,
33 resource_callback,
34 }
35 }
36 pub fn shared(res_callback: SharedCallback<D>) -> Arc<dyn NetProvider<D>> {
37 Arc::new(Self::new(res_callback))
38 }
39 pub fn is_empty(&self) -> bool {
40 Arc::strong_count(&self.resource_callback) == 1
41 }
42 pub fn count(&self) -> usize {
43 Arc::strong_count(&self.resource_callback) - 1
44 }
45}
46impl<D: 'static> Provider<D> {
47 async fn fetch_inner(
48 client: Client,
49 request: Request,
50 ) -> Result<(String, Bytes), ProviderError> {
51 Ok(match request.url.scheme() {
52 "data" => {
53 let data_url = DataUrl::process(request.url.as_str())?;
54 let decoded = data_url.decode_to_vec()?;
55 (request.url.to_string(), Bytes::from(decoded.0))
56 }
57 "file" => {
58 let file_content = std::fs::read(request.url.path())?;
59 (request.url.to_string(), Bytes::from(file_content))
60 }
61 _ => {
62 let response = client
63 .request(request.method, request.url)
64 .headers(request.headers)
65 .header("Content-Type", request.content_type.as_str())
66 .header("User-Agent", USER_AGENT)
67 .apply_body(request.body, request.content_type.as_str())
68 .await
69 .send()
70 .await?;
71
72 (response.url().to_string(), response.bytes().await?)
73 }
74 })
75 }
76
77 async fn fetch_with_handler(
78 client: Client,
79 doc_id: usize,
80 request: Request,
81 handler: BoxedHandler<D>,
82 res_callback: SharedCallback<D>,
83 ) -> Result<(), ProviderError> {
84 let (_response_url, bytes) = Self::fetch_inner(client, request).await?;
85 handler.bytes(doc_id, bytes, res_callback);
86 Ok(())
87 }
88
89 #[allow(clippy::type_complexity)]
90 pub fn fetch_with_callback(
91 &self,
92 request: Request,
93 callback: Box<dyn FnOnce(Result<(String, Bytes), ProviderError>) + Send + Sync + 'static>,
94 ) {
95 let client = self.client.clone();
96 self.rt.spawn(async move {
97 let url = request.url.to_string();
98 let result = Self::fetch_inner(client, request).await;
99 if let Err(e) = &result {
100 eprintln!("Error fetching {url}: {e:?}");
101 } else {
102 println!("Success {url}");
103 }
104 callback(result);
105 });
106 }
107
108 pub async fn fetch_async(&self, request: Request) -> Result<(String, Bytes), ProviderError> {
109 let client = self.client.clone();
110 let url = request.url.to_string();
111 let result = Self::fetch_inner(client, request).await;
112 if let Err(e) = &result {
113 eprintln!("Error fetching {url}: {e:?}");
114 } else {
115 println!("Success {url}");
116 }
117 result
118 }
119}
120
121impl<D: 'static> NetProvider<D> for Provider<D> {
122 fn fetch(&self, doc_id: usize, request: Request, handler: BoxedHandler<D>) {
123 let client = self.client.clone();
124 let callback = Arc::clone(&self.resource_callback);
125 println!("Fetching {}", &request.url);
126 self.rt.spawn(async move {
127 let url = request.url.to_string();
128 let res = Self::fetch_with_handler(client, doc_id, request, handler, callback).await;
129 if let Err(e) = res {
130 eprintln!("Error fetching {url}: {e:?}");
131 } else {
132 println!("Success {url}");
133 }
134 });
135 }
136}
137
138#[derive(Debug)]
139pub enum ProviderError {
140 Io(std::io::Error),
141 DataUrl(data_url::DataUrlError),
142 DataUrlBase64(data_url::forgiving_base64::InvalidBase64),
143 ReqwestError(reqwest::Error),
144}
145
146impl From<std::io::Error> for ProviderError {
147 fn from(value: std::io::Error) -> Self {
148 Self::Io(value)
149 }
150}
151
152impl From<data_url::DataUrlError> for ProviderError {
153 fn from(value: data_url::DataUrlError) -> Self {
154 Self::DataUrl(value)
155 }
156}
157
158impl From<data_url::forgiving_base64::InvalidBase64> for ProviderError {
159 fn from(value: data_url::forgiving_base64::InvalidBase64) -> Self {
160 Self::DataUrlBase64(value)
161 }
162}
163
164impl From<reqwest::Error> for ProviderError {
165 fn from(value: reqwest::Error) -> Self {
166 Self::ReqwestError(value)
167 }
168}
169
170pub struct MpscCallback<T>(UnboundedSender<(usize, T)>);
171impl<T> MpscCallback<T> {
172 pub fn new() -> (UnboundedReceiver<(usize, T)>, Self) {
173 let (send, recv) = unbounded_channel();
174 (recv, Self(send))
175 }
176}
177impl<T: Send + Sync + 'static> NetCallback<T> for MpscCallback<T> {
178 fn call(&self, doc_id: usize, result: Result<T, Option<String>>) {
179 if let Ok(data) = result {
181 let _ = self.0.send((doc_id, data));
182 }
183 }
184}
185
/// Extension trait for attaching a blitz [`Body`] to a `reqwest` request builder.
///
/// The method is `async` because the implementation may read files from disk while
/// building a multipart body.
trait ReqwestExt {
    /// Consume the builder and return it with `body` applied, using `content_type` to choose
    /// how form data is encoded.
    async fn apply_body(self, body: Body, content_type: &str) -> Self;
}
189impl ReqwestExt for reqwest::RequestBuilder {
190 async fn apply_body(self, body: Body, content_type: &str) -> Self {
191 match body {
192 Body::Bytes(bytes) => self.body(bytes),
193 Body::Form(form_data) => match content_type {
194 "application/x-www-form-urlencoded" => self.form(&form_data),
195 #[cfg(feature = "multipart")]
196 "multipart/form-data" => {
197 use blitz_traits::net::Entry;
198 use blitz_traits::net::EntryValue;
199 let mut form_data = form_data;
200 let mut form = reqwest::multipart::Form::new();
201 for Entry { name, value } in form_data.0.drain(..) {
202 form = match value {
203 EntryValue::String(value) => form.text(name, value),
204 EntryValue::File(path_buf) => form
205 .file(name, path_buf)
206 .await
207 .expect("Couldn't read form file from disk"),
208 EntryValue::EmptyFile => form.part(
209 name,
210 reqwest::multipart::Part::bytes(&[])
211 .mime_str("application/octet-stream")
212 .unwrap(),
213 ),
214 };
215 }
216 self.multipart(form)
217 }
218 _ => self,
219 },
220 Body::Empty => self,
221 }
222 }
223}