blitz_net/
lib.rs

1//! Networking (HTTP, filesystem, Data URIs) for Blitz
2//!
3//! Provides an implementation of the [`blitz_traits::net::NetProvider`] trait.
4
5use blitz_traits::net::{
6    Body, BoxedHandler, Bytes, NetCallback, NetProvider, Request, SharedCallback,
7};
8use data_url::DataUrl;
9use reqwest::Client;
10use std::sync::Arc;
11use tokio::{
12    runtime::Handle,
13    sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
14};
15
16const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64; rv:60.0) Gecko/20100101 Firefox/81.0";
17
18pub struct Provider<D> {
19    rt: Handle,
20    client: Client,
21    resource_callback: SharedCallback<D>,
22}
23impl<D: 'static> Provider<D> {
24    pub fn new(resource_callback: SharedCallback<D>) -> Self {
25        #[cfg(feature = "cookies")]
26        let client = Client::builder().cookie_store(true).build().unwrap();
27        #[cfg(not(feature = "cookies"))]
28        let client = Client::new();
29
30        Self {
31            rt: Handle::current(),
32            client,
33            resource_callback,
34        }
35    }
36    pub fn shared(res_callback: SharedCallback<D>) -> Arc<dyn NetProvider<D>> {
37        Arc::new(Self::new(res_callback))
38    }
39    pub fn is_empty(&self) -> bool {
40        Arc::strong_count(&self.resource_callback) == 1
41    }
42    pub fn count(&self) -> usize {
43        Arc::strong_count(&self.resource_callback) - 1
44    }
45}
46impl<D: 'static> Provider<D> {
47    async fn fetch_inner(
48        client: Client,
49        request: Request,
50    ) -> Result<(String, Bytes), ProviderError> {
51        Ok(match request.url.scheme() {
52            "data" => {
53                let data_url = DataUrl::process(request.url.as_str())?;
54                let decoded = data_url.decode_to_vec()?;
55                (request.url.to_string(), Bytes::from(decoded.0))
56            }
57            "file" => {
58                let file_content = std::fs::read(request.url.path())?;
59                (request.url.to_string(), Bytes::from(file_content))
60            }
61            _ => {
62                let response = client
63                    .request(request.method, request.url)
64                    .headers(request.headers)
65                    .header("Content-Type", request.content_type.as_str())
66                    .header("User-Agent", USER_AGENT)
67                    .apply_body(request.body, request.content_type.as_str())
68                    .await
69                    .send()
70                    .await?;
71
72                (response.url().to_string(), response.bytes().await?)
73            }
74        })
75    }
76
77    async fn fetch_with_handler(
78        client: Client,
79        doc_id: usize,
80        request: Request,
81        handler: BoxedHandler<D>,
82        res_callback: SharedCallback<D>,
83    ) -> Result<(), ProviderError> {
84        let (_response_url, bytes) = Self::fetch_inner(client, request).await?;
85        handler.bytes(doc_id, bytes, res_callback);
86        Ok(())
87    }
88
89    #[allow(clippy::type_complexity)]
90    pub fn fetch_with_callback(
91        &self,
92        request: Request,
93        callback: Box<dyn FnOnce(Result<(String, Bytes), ProviderError>) + Send + Sync + 'static>,
94    ) {
95        #[cfg(feature = "debug_log")]
96        let url = request.url.to_string();
97
98        let client = self.client.clone();
99        self.rt.spawn(async move {
100            let result = Self::fetch_inner(client, request).await;
101
102            #[cfg(feature = "debug_log")]
103            if let Err(e) = &result {
104                eprintln!("Error fetching {url}: {e:?}");
105            } else {
106                println!("Success {url}");
107            }
108
109            callback(result);
110        });
111    }
112
113    pub async fn fetch_async(&self, request: Request) -> Result<(String, Bytes), ProviderError> {
114        #[cfg(feature = "debug_log")]
115        let url = request.url.to_string();
116
117        let client = self.client.clone();
118        let result = Self::fetch_inner(client, request).await;
119
120        #[cfg(feature = "debug_log")]
121        if let Err(e) = &result {
122            eprintln!("Error fetching {url}: {e:?}");
123        } else {
124            println!("Success {url}");
125        }
126
127        result
128    }
129}
130
131impl<D: 'static> NetProvider<D> for Provider<D> {
132    fn fetch(&self, doc_id: usize, request: Request, handler: BoxedHandler<D>) {
133        let client = self.client.clone();
134        let callback = Arc::clone(&self.resource_callback);
135
136        #[cfg(feature = "debug_log")]
137        println!("Fetching {}", &request.url);
138
139        self.rt.spawn(async move {
140            #[cfg(feature = "debug_log")]
141            let url = request.url.to_string();
142
143            let _res = Self::fetch_with_handler(client, doc_id, request, handler, callback).await;
144
145            #[cfg(feature = "debug_log")]
146            if let Err(e) = _res {
147                eprintln!("Error fetching {url}: {e:?}");
148            } else {
149                println!("Success {url}");
150            }
151        });
152    }
153}
154
155#[derive(Debug)]
156pub enum ProviderError {
157    Io(std::io::Error),
158    DataUrl(data_url::DataUrlError),
159    DataUrlBase64(data_url::forgiving_base64::InvalidBase64),
160    ReqwestError(reqwest::Error),
161}
162
163impl From<std::io::Error> for ProviderError {
164    fn from(value: std::io::Error) -> Self {
165        Self::Io(value)
166    }
167}
168
169impl From<data_url::DataUrlError> for ProviderError {
170    fn from(value: data_url::DataUrlError) -> Self {
171        Self::DataUrl(value)
172    }
173}
174
175impl From<data_url::forgiving_base64::InvalidBase64> for ProviderError {
176    fn from(value: data_url::forgiving_base64::InvalidBase64) -> Self {
177        Self::DataUrlBase64(value)
178    }
179}
180
181impl From<reqwest::Error> for ProviderError {
182    fn from(value: reqwest::Error) -> Self {
183        Self::ReqwestError(value)
184    }
185}
186
187pub struct MpscCallback<T>(UnboundedSender<(usize, T)>);
188impl<T> MpscCallback<T> {
189    pub fn new() -> (UnboundedReceiver<(usize, T)>, Self) {
190        let (send, recv) = unbounded_channel();
191        (recv, Self(send))
192    }
193}
194impl<T: Send + Sync + 'static> NetCallback<T> for MpscCallback<T> {
195    fn call(&self, doc_id: usize, result: Result<T, Option<String>>) {
196        // TODO: handle error case
197        if let Ok(data) = result {
198            let _ = self.0.send((doc_id, data));
199        }
200    }
201}
202
203trait ReqwestExt {
204    async fn apply_body(self, body: Body, content_type: &str) -> Self;
205}
206impl ReqwestExt for reqwest::RequestBuilder {
207    async fn apply_body(self, body: Body, content_type: &str) -> Self {
208        match body {
209            Body::Bytes(bytes) => self.body(bytes),
210            Body::Form(form_data) => match content_type {
211                "application/x-www-form-urlencoded" => self.form(&form_data),
212                #[cfg(feature = "multipart")]
213                "multipart/form-data" => {
214                    use blitz_traits::net::Entry;
215                    use blitz_traits::net::EntryValue;
216                    let mut form_data = form_data;
217                    let mut form = reqwest::multipart::Form::new();
218                    for Entry { name, value } in form_data.0.drain(..) {
219                        form = match value {
220                            EntryValue::String(value) => form.text(name, value),
221                            EntryValue::File(path_buf) => form
222                                .file(name, path_buf)
223                                .await
224                                .expect("Couldn't read form file from disk"),
225                            EntryValue::EmptyFile => form.part(
226                                name,
227                                reqwest::multipart::Part::bytes(&[])
228                                    .mime_str("application/octet-stream")
229                                    .unwrap(),
230                            ),
231                        };
232                    }
233                    self.multipart(form)
234                }
235                _ => self,
236            },
237            Body::Empty => self,
238        }
239    }
240}