blitz_net/
lib.rs

//! Networking (HTTP(S), `file:` URLs and `data:` URIs) for Blitz.
//!
//! Provides an implementation of the [`blitz_traits::net::NetProvider`] trait.
4
5use blitz_traits::net::{
6    Body, BoxedHandler, Bytes, NetCallback, NetProvider, Request, SharedCallback,
7};
8use data_url::DataUrl;
9use reqwest::Client;
10use std::sync::Arc;
11use tokio::{
12    runtime::Handle,
13    sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
14};
15
/// `User-Agent` header value sent with every HTTP(S) request made by [`Provider`].
///
/// Imitates desktop Firefox on Linux (presumably so that user-agent-sniffing servers serve
/// regular desktop content). Firefox's `rv:` token always equals its major version, so the
/// two are kept in sync here; a mismatched pair is an easy fingerprint for bot detection.
const USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64; rv:81.0) Gecko/20100101 Firefox/81.0";
17
18pub struct Provider<D> {
19    rt: Handle,
20    client: Client,
21    resource_callback: SharedCallback<D>,
22}
23impl<D: 'static> Provider<D> {
24    pub fn new(resource_callback: SharedCallback<D>) -> Self {
25        #[cfg(feature = "cookies")]
26        let client = Client::builder().cookie_store(true).build().unwrap();
27        #[cfg(not(feature = "cookies"))]
28        let client = Client::new();
29
30        Self {
31            rt: Handle::current(),
32            client,
33            resource_callback,
34        }
35    }
36    pub fn shared(res_callback: SharedCallback<D>) -> Arc<dyn NetProvider<D>> {
37        Arc::new(Self::new(res_callback))
38    }
39    pub fn is_empty(&self) -> bool {
40        Arc::strong_count(&self.resource_callback) == 1
41    }
42    pub fn count(&self) -> usize {
43        Arc::strong_count(&self.resource_callback) - 1
44    }
45}
46impl<D: 'static> Provider<D> {
47    async fn fetch_inner(
48        client: Client,
49        request: Request,
50    ) -> Result<(String, Bytes), ProviderError> {
51        Ok(match request.url.scheme() {
52            "data" => {
53                let data_url = DataUrl::process(request.url.as_str())?;
54                let decoded = data_url.decode_to_vec()?;
55                (request.url.to_string(), Bytes::from(decoded.0))
56            }
57            "file" => {
58                let file_content = std::fs::read(request.url.path())?;
59                (request.url.to_string(), Bytes::from(file_content))
60            }
61            _ => {
62                let response = client
63                    .request(request.method, request.url)
64                    .headers(request.headers)
65                    .header("Content-Type", request.content_type.as_str())
66                    .header("User-Agent", USER_AGENT)
67                    .apply_body(request.body, request.content_type.as_str())
68                    .await
69                    .send()
70                    .await?;
71
72                (response.url().to_string(), response.bytes().await?)
73            }
74        })
75    }
76
77    async fn fetch_with_handler(
78        client: Client,
79        doc_id: usize,
80        request: Request,
81        handler: BoxedHandler<D>,
82        res_callback: SharedCallback<D>,
83    ) -> Result<(), ProviderError> {
84        let (_response_url, bytes) = Self::fetch_inner(client, request).await?;
85        handler.bytes(doc_id, bytes, res_callback);
86        Ok(())
87    }
88
89    #[allow(clippy::type_complexity)]
90    pub fn fetch_with_callback(
91        &self,
92        request: Request,
93        callback: Box<dyn FnOnce(Result<(String, Bytes), ProviderError>) + Send + Sync + 'static>,
94    ) {
95        let client = self.client.clone();
96        self.rt.spawn(async move {
97            let url = request.url.to_string();
98            let result = Self::fetch_inner(client, request).await;
99            if let Err(e) = &result {
100                eprintln!("Error fetching {url}: {e:?}");
101            } else {
102                println!("Success {url}");
103            }
104            callback(result);
105        });
106    }
107
108    pub async fn fetch_async(&self, request: Request) -> Result<(String, Bytes), ProviderError> {
109        let client = self.client.clone();
110        let url = request.url.to_string();
111        let result = Self::fetch_inner(client, request).await;
112        if let Err(e) = &result {
113            eprintln!("Error fetching {url}: {e:?}");
114        } else {
115            println!("Success {url}");
116        }
117        result
118    }
119}
120
121impl<D: 'static> NetProvider<D> for Provider<D> {
122    fn fetch(&self, doc_id: usize, request: Request, handler: BoxedHandler<D>) {
123        let client = self.client.clone();
124        let callback = Arc::clone(&self.resource_callback);
125        println!("Fetching {}", &request.url);
126        self.rt.spawn(async move {
127            let url = request.url.to_string();
128            let res = Self::fetch_with_handler(client, doc_id, request, handler, callback).await;
129            if let Err(e) = res {
130                eprintln!("Error fetching {url}: {e:?}");
131            } else {
132                println!("Success {url}");
133            }
134        });
135    }
136}
137
138#[derive(Debug)]
139pub enum ProviderError {
140    Io(std::io::Error),
141    DataUrl(data_url::DataUrlError),
142    DataUrlBase64(data_url::forgiving_base64::InvalidBase64),
143    ReqwestError(reqwest::Error),
144}
145
146impl From<std::io::Error> for ProviderError {
147    fn from(value: std::io::Error) -> Self {
148        Self::Io(value)
149    }
150}
151
152impl From<data_url::DataUrlError> for ProviderError {
153    fn from(value: data_url::DataUrlError) -> Self {
154        Self::DataUrl(value)
155    }
156}
157
158impl From<data_url::forgiving_base64::InvalidBase64> for ProviderError {
159    fn from(value: data_url::forgiving_base64::InvalidBase64) -> Self {
160        Self::DataUrlBase64(value)
161    }
162}
163
164impl From<reqwest::Error> for ProviderError {
165    fn from(value: reqwest::Error) -> Self {
166        Self::ReqwestError(value)
167    }
168}
169
170pub struct MpscCallback<T>(UnboundedSender<(usize, T)>);
171impl<T> MpscCallback<T> {
172    pub fn new() -> (UnboundedReceiver<(usize, T)>, Self) {
173        let (send, recv) = unbounded_channel();
174        (recv, Self(send))
175    }
176}
177impl<T: Send + Sync + 'static> NetCallback<T> for MpscCallback<T> {
178    fn call(&self, doc_id: usize, result: Result<T, Option<String>>) {
179        // TODO: handle error case
180        if let Ok(data) = result {
181            let _ = self.0.send((doc_id, data));
182        }
183    }
184}
185
186trait ReqwestExt {
187    async fn apply_body(self, body: Body, content_type: &str) -> Self;
188}
189impl ReqwestExt for reqwest::RequestBuilder {
190    async fn apply_body(self, body: Body, content_type: &str) -> Self {
191        match body {
192            Body::Bytes(bytes) => self.body(bytes),
193            Body::Form(form_data) => match content_type {
194                "application/x-www-form-urlencoded" => self.form(&form_data),
195                #[cfg(feature = "multipart")]
196                "multipart/form-data" => {
197                    use blitz_traits::net::Entry;
198                    use blitz_traits::net::EntryValue;
199                    let mut form_data = form_data;
200                    let mut form = reqwest::multipart::Form::new();
201                    for Entry { name, value } in form_data.0.drain(..) {
202                        form = match value {
203                            EntryValue::String(value) => form.text(name, value),
204                            EntryValue::File(path_buf) => form
205                                .file(name, path_buf)
206                                .await
207                                .expect("Couldn't read form file from disk"),
208                            EntryValue::EmptyFile => form.part(
209                                name,
210                                reqwest::multipart::Part::bytes(&[])
211                                    .mime_str("application/octet-stream")
212                                    .unwrap(),
213                            ),
214                        };
215                    }
216                    self.multipart(form)
217                }
218                _ => self,
219            },
220            Body::Empty => self,
221        }
222    }
223}