// cranpose_services/http.rs

1use cranpose_core::{compositionLocalOf, CompositionLocal};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
/// Errors reported by the HTTP helpers and [`HttpClient`] futures in this module.
///
/// Every variant carries only owned `String`/`u16` data, so the whole enum is `Clone`.
#[derive(thiserror::Error, Debug, Clone)]
pub enum HttpError {
    /// The underlying HTTP client could not be constructed; holds the builder error's text.
    #[error("Failed to build HTTP client: {0}")]
    ClientInit(String),
    /// Creating or sending the request for `url` failed before a response was obtained.
    #[error("Request failed for {url}: {message}")]
    RequestFailed { url: String, message: String },
    /// A response arrived for `url`, but its status code was not a success status.
    #[error("Request failed with status {status} for {url}")]
    HttpStatus { url: String, status: u16 },
    /// The response for `url` was received, but reading its body failed.
    #[error("Failed to read response body for {url}: {message}")]
    BodyReadFailed { url: String, message: String },
    /// The value obtained for `url` did not have the expected shape (for example, the
    /// fetched value was not a `Response`, or the text body was not a string).
    #[error("Invalid response for {url}: {message}")]
    InvalidResponse { url: String, message: String },
    /// `web_sys::window()` returned `None`, so no `fetch` could be issued.
    #[error("No window object available")]
    NoWindow,
}
21
/// Boxed, pinned future returned by [`HttpClient`] methods; resolves to
/// `Result<T, HttpError>`.
///
/// Non-`wasm32` variant: the future is additionally required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + Send + 'a>>;

/// Boxed, pinned future returned by [`HttpClient`] methods; resolves to
/// `Result<T, HttpError>`.
///
/// `wasm32` variant: no `Send` bound is imposed.
// NOTE(review): presumably because the JS-backed futures used on this target are not
// `Send` — confirm.
#[cfg(target_arch = "wasm32")]
pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + 'a>>;
27
28pub trait HttpClient: Send + Sync {
29    fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String>;
30
31    fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
32        Box::pin(async move { self.get_text(url).await.map(|text| text.into_bytes()) })
33    }
34}
35
/// Shared, reference-counted handle to a type-erased [`HttpClient`].
pub type HttpClientRef = Arc<dyn HttpClient>;
37
/// Stateless [`HttpClient`] that delegates to the platform-specific fetch helpers in
/// this file (native helpers off `wasm32`, web helpers on `wasm32`).
struct DefaultHttpClient;
39
40impl HttpClient for DefaultHttpClient {
41    fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String> {
42        Box::pin(async move {
43            #[cfg(not(target_arch = "wasm32"))]
44            {
45                fetch_text_native(url)
46            }
47
48            #[cfg(target_arch = "wasm32")]
49            {
50                fetch_text_web(url).await
51            }
52        })
53    }
54
55    fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
56        Box::pin(async move {
57            #[cfg(not(target_arch = "wasm32"))]
58            {
59                fetch_bytes_native(url)
60            }
61
62            #[cfg(target_arch = "wasm32")]
63            {
64                fetch_bytes_web(url).await
65            }
66        })
67    }
68}
69
70#[cfg(not(target_arch = "wasm32"))]
71fn fetch_text_native(url: &str) -> Result<String, HttpError> {
72    use std::sync::OnceLock;
73    use std::time::Duration;
74
75    static CLIENT: OnceLock<Result<reqwest::blocking::Client, HttpError>> = OnceLock::new();
76    let client = CLIENT
77        .get_or_init(|| {
78            reqwest::blocking::Client::builder()
79                .timeout(Duration::from_secs(10))
80                .user_agent("cranpose/0.1")
81                .build()
82                .map_err(|err| HttpError::ClientInit(err.to_string()))
83        })
84        .as_ref()
85        .map_err(|err| err.clone())?;
86
87    let response = client
88        .get(url)
89        .send()
90        .map_err(|err| HttpError::RequestFailed {
91            url: url.to_string(),
92            message: err.to_string(),
93        })?;
94
95    let status = response.status();
96    if !status.is_success() {
97        return Err(HttpError::HttpStatus {
98            url: url.to_string(),
99            status: status.as_u16(),
100        });
101    }
102
103    response.text().map_err(|err| HttpError::BodyReadFailed {
104        url: url.to_string(),
105        message: err.to_string(),
106    })
107}
108
109#[cfg(not(target_arch = "wasm32"))]
110fn fetch_bytes_native(url: &str) -> Result<Vec<u8>, HttpError> {
111    use std::sync::OnceLock;
112    use std::time::Duration;
113
114    static CLIENT: OnceLock<Result<reqwest::blocking::Client, HttpError>> = OnceLock::new();
115    let client = CLIENT
116        .get_or_init(|| {
117            reqwest::blocking::Client::builder()
118                .timeout(Duration::from_secs(10))
119                .user_agent("cranpose/0.1")
120                .build()
121                .map_err(|err| HttpError::ClientInit(err.to_string()))
122        })
123        .as_ref()
124        .map_err(|err| err.clone())?;
125
126    let response = client
127        .get(url)
128        .send()
129        .map_err(|err| HttpError::RequestFailed {
130            url: url.to_string(),
131            message: err.to_string(),
132        })?;
133
134    let status = response.status();
135    if !status.is_success() {
136        return Err(HttpError::HttpStatus {
137            url: url.to_string(),
138            status: status.as_u16(),
139        });
140    }
141
142    response
143        .bytes()
144        .map(|bytes| bytes.to_vec())
145        .map_err(|err| HttpError::BodyReadFailed {
146            url: url.to_string(),
147            message: err.to_string(),
148        })
149}
150
/// Issues a CORS-mode `GET` for `url` through the window's `fetch` and returns the
/// `Response` once its status has been checked.
///
/// # Errors
///
/// * [`HttpError::RequestFailed`] if the `Request` cannot be created or the fetch
///   promise rejects.
/// * [`HttpError::NoWindow`] if `web_sys::window()` returns `None`.
/// * [`HttpError::InvalidResponse`] if the fetched value is not a `Response`.
/// * [`HttpError::HttpStatus`] if `Response::ok()` is false.
#[cfg(target_arch = "wasm32")]
async fn fetch_response_web(url: &str) -> Result<web_sys::Response, HttpError> {
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;
    use web_sys::{Request, RequestInit, RequestMode, Response};

    let opts = RequestInit::new();
    opts.set_method("GET");
    opts.set_mode(RequestMode::Cors);

    let request =
        Request::new_with_str_and_init(url, &opts).map_err(|err| HttpError::RequestFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;

    let window = web_sys::window().ok_or(HttpError::NoWindow)?;
    let resp_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(|err| HttpError::RequestFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;

    let resp: Response = resp_value
        .dyn_into()
        .map_err(|_| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response is not a Response object".to_string(),
        })?;

    if !resp.ok() {
        return Err(HttpError::HttpStatus {
            url: url.to_string(),
            status: resp.status(),
        });
    }

    Ok(resp)
}

/// Fetches `url` in the browser and resolves to the response body as text.
///
/// # Errors
///
/// Any error from issuing the request, [`HttpError::BodyReadFailed`] if obtaining or
/// awaiting the body promise fails, or [`HttpError::InvalidResponse`] if the resolved
/// body value is not a string.
#[cfg(target_arch = "wasm32")]
async fn fetch_text_web(url: &str) -> Result<String, HttpError> {
    use wasm_bindgen_futures::JsFuture;

    // Both body-reading steps report failures the same way.
    let body_read_failed = |err: wasm_bindgen::JsValue| HttpError::BodyReadFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    };

    let resp = fetch_response_web(url).await?;
    let text_promise = resp.text().map_err(body_read_failed)?;
    let text_value = JsFuture::from(text_promise)
        .await
        .map_err(body_read_failed)?;

    text_value
        .as_string()
        .ok_or_else(|| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response body is not a string".to_string(),
        })
}

/// Fetches `url` in the browser and resolves to the response body as raw bytes.
///
/// The body is read as an `ArrayBuffer` and copied into a `Vec<u8>`.
///
/// # Errors
///
/// Any error from issuing the request, or [`HttpError::BodyReadFailed`] if obtaining or
/// awaiting the body promise fails.
#[cfg(target_arch = "wasm32")]
async fn fetch_bytes_web(url: &str) -> Result<Vec<u8>, HttpError> {
    use wasm_bindgen_futures::JsFuture;

    // Both body-reading steps report failures the same way.
    let body_read_failed = |err: wasm_bindgen::JsValue| HttpError::BodyReadFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    };

    let resp = fetch_response_web(url).await?;
    let bytes_promise = resp.array_buffer().map_err(body_read_failed)?;
    let bytes_value = JsFuture::from(bytes_promise)
        .await
        .map_err(body_read_failed)?;

    let array = js_sys::Uint8Array::new(&bytes_value);
    Ok(array.to_vec())
}
264
/// Creates a new shared handle to the built-in [`DefaultHttpClient`].
///
/// Each call allocates a fresh `Arc`; the returned handles are not pointer-equal
/// across calls.
pub fn default_http_client() -> HttpClientRef {
    Arc::new(DefaultHttpClient)
}
268
269pub fn local_http_client() -> CompositionLocal<HttpClientRef> {
270    thread_local! {
271        static LOCAL_HTTP_CLIENT: std::cell::RefCell<Option<CompositionLocal<HttpClientRef>>> = const { std::cell::RefCell::new(None) };
272    }
273
274    LOCAL_HTTP_CLIENT.with(|cell| {
275        let mut local = cell.borrow_mut();
276        if local.is_none() {
277            *local = Some(compositionLocalOf(default_http_client));
278        }
279        local
280            .as_ref()
281            .expect("HTTP client composition local must be initialized")
282            .clone()
283    })
284}
285
/// Unit tests for the client handle, the trait's default `get_bytes`, and overriding
/// the composition-local client.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_test_composition;
    use cranpose_core::CompositionLocalProvider;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Stub client whose `get_text` always resolves to `"ok"`, whatever the URL.
    /// It does not override `get_bytes`, so the trait's default is exercised.
    struct TestHttpClient;

    impl HttpClient for TestHttpClient {
        fn get_text<'a>(&'a self, _url: &'a str) -> HttpFuture<'a, String> {
            Box::pin(async { Ok("ok".to_string()) })
        }
    }

    /// The default client is handed out as an `Arc` whose strong count tracks clones.
    #[test]
    fn default_http_client_is_available() {
        let client = default_http_client();
        let cloned = client.clone();
        assert_eq!(Arc::strong_count(&client), 2);
        drop(cloned);
        assert_eq!(Arc::strong_count(&client), 1);
    }

    /// The trait's default `get_bytes` yields the bytes of whatever `get_text` returns.
    #[test]
    fn test_client_uses_default_get_bytes_from_text() {
        let client = TestHttpClient;
        let bytes = pollster::block_on(client.get_bytes("https://example.com")).expect("bytes");
        assert_eq!(bytes, b"ok".to_vec());
    }

    /// Providing a custom client through `CompositionLocalProvider` makes
    /// `current()` return that exact `Arc` inside the provider's content.
    #[test]
    fn local_http_client_can_be_overridden() {
        let local = local_http_client();
        let default_client = default_http_client();
        let custom_client: HttpClientRef = Arc::new(TestHttpClient);
        // Slot the composition writes the observed client into, so it can be
        // inspected after the composition has run.
        let captured = Rc::new(RefCell::new(None));

        {
            // Separate clones for each closure that needs to own its copy.
            let captured_for_closure = Rc::clone(&captured);
            let custom_client = custom_client.clone();
            let local_for_provider = local.clone();
            let local_for_read = local.clone();
            run_test_composition(move || {
                let captured = Rc::clone(&captured_for_closure);
                let local_for_read = local_for_read.clone();
                CompositionLocalProvider(
                    vec![local_for_provider.provides(custom_client.clone())],
                    move || {
                        let current = local_for_read.current();
                        *captured.borrow_mut() = Some(current);
                    },
                );
            });
        }

        let current = captured.borrow().as_ref().expect("client captured").clone();
        // Pointer identity: the provided client was observed, not a default one.
        assert!(Arc::ptr_eq(&current, &custom_client));
        assert!(!Arc::ptr_eq(&current, &default_client));
    }
}