Skip to main content

cranpose_services/
http.rs

1use cranpose_core::{compositionLocalOf, CompositionLocal};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
/// Errors produced while performing HTTP requests through this module's clients.
///
/// `Clone` is derived so that a cached client-construction failure can be handed
/// out repeatedly (see `native_client`).
#[derive(thiserror::Error, Debug, Clone)]
pub enum HttpError {
    /// Building the native HTTP client failed; carries the builder's error message.
    #[error("Failed to build HTTP client: {0}")]
    ClientInit(String),
    /// The request could not be constructed or sent.
    #[error("Request failed for {url}: {message}")]
    RequestFailed { url: String, message: String },
    /// A response arrived, but its status code did not indicate success.
    #[error("Request failed with status {status} for {url}")]
    HttpStatus { url: String, status: u16 },
    /// A successful response arrived, but reading its body failed.
    #[error("Failed to read response body for {url}: {message}")]
    BodyReadFailed { url: String, message: String },
    /// The response or its body did not have the expected shape.
    /// In this module it is only raised by the `wasm32` fetch path.
    #[error("Invalid response for {url}: {message}")]
    InvalidResponse { url: String, message: String },
    /// `web_sys::window()` returned no global `Window` (`wasm32` only).
    #[error("No window object available")]
    NoWindow,
}
21
/// Boxed future returned by [`HttpClient`] methods (native targets).
///
/// On native targets the future is required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + Send + 'a>>;

/// Boxed future returned by [`HttpClient`] methods (`wasm32`).
///
/// The `Send` bound is dropped here. Presumably this is because the browser fetch
/// futures used on this target are not `Send` — NOTE(review): confirm.
#[cfg(target_arch = "wasm32")]
pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + 'a>>;
27
28pub trait HttpClient: Send + Sync {
29    fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String>;
30
31    fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
32        Box::pin(async move { self.get_text(url).await.map(|text| text.into_bytes()) })
33    }
34}
35
/// Shared, reference-counted handle to any [`HttpClient`] implementation.
pub type HttpClientRef = Arc<dyn HttpClient>;
37
38struct DefaultHttpClient;
39
40impl HttpClient for DefaultHttpClient {
41    fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String> {
42        Box::pin(async move {
43            #[cfg(not(target_arch = "wasm32"))]
44            {
45                fetch_text_native(url)
46            }
47
48            #[cfg(target_arch = "wasm32")]
49            {
50                fetch_text_web(url).await
51            }
52        })
53    }
54
55    fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
56        Box::pin(async move {
57            #[cfg(not(target_arch = "wasm32"))]
58            {
59                fetch_bytes_native(url)
60            }
61
62            #[cfg(target_arch = "wasm32")]
63            {
64                fetch_bytes_web(url).await
65            }
66        })
67    }
68}
69
70#[cfg(not(target_arch = "wasm32"))]
71fn fetch_text_native(url: &str) -> Result<String, HttpError> {
72    native_response(url)?
73        .text()
74        .map_err(|err| HttpError::BodyReadFailed {
75            url: url.to_string(),
76            message: err.to_string(),
77        })
78}
79
80#[cfg(not(target_arch = "wasm32"))]
81fn fetch_bytes_native(url: &str) -> Result<Vec<u8>, HttpError> {
82    native_response(url)?
83        .bytes()
84        .map(|bytes| bytes.to_vec())
85        .map_err(|err| HttpError::BodyReadFailed {
86            url: url.to_string(),
87            message: err.to_string(),
88        })
89}
90
91#[cfg(not(target_arch = "wasm32"))]
92fn native_response(url: &str) -> Result<reqwest::blocking::Response, HttpError> {
93    let response = native_client()?
94        .get(url)
95        .send()
96        .map_err(|err| HttpError::RequestFailed {
97            url: url.to_string(),
98            message: err.to_string(),
99        })?;
100
101    let status = response.status();
102    if !status.is_success() {
103        return Err(HttpError::HttpStatus {
104            url: url.to_string(),
105            status: status.as_u16(),
106        });
107    }
108
109    Ok(response)
110}
111
112#[cfg(not(target_arch = "wasm32"))]
113fn native_client() -> Result<&'static reqwest::blocking::Client, HttpError> {
114    use std::sync::OnceLock;
115
116    static CLIENT: OnceLock<Result<reqwest::blocking::Client, HttpError>> = OnceLock::new();
117    CLIENT
118        .get_or_init(build_native_client)
119        .as_ref()
120        .map_err(Clone::clone)
121}
122
123#[cfg(not(target_arch = "wasm32"))]
124fn build_native_client() -> Result<reqwest::blocking::Client, HttpError> {
125    use std::time::Duration;
126
127    reqwest::blocking::Client::builder()
128        .timeout(Duration::from_secs(10))
129        .user_agent("cranpose/0.1")
130        .build()
131        .map_err(|err| HttpError::ClientInit(err.to_string()))
132}
133
/// Sends a CORS `GET` request for `url` through `window.fetch`.
///
/// The response is returned only when `Response.ok` is true. This helper is
/// shared by [`fetch_text_web`] and [`fetch_bytes_web`] so that request
/// construction and status handling are defined once.
///
/// # Errors
///
/// - [`HttpError::RequestFailed`] if the request cannot be built or the fetch rejects.
/// - [`HttpError::NoWindow`] if there is no global `Window`.
/// - [`HttpError::InvalidResponse`] if the fetch result is not a `Response`.
/// - [`HttpError::HttpStatus`] if the status is not OK.
#[cfg(target_arch = "wasm32")]
async fn fetch_response_web(url: &str) -> Result<web_sys::Response, HttpError> {
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;
    use web_sys::{Request, RequestInit, RequestMode, Response};

    let opts = RequestInit::new();
    opts.set_method("GET");
    opts.set_mode(RequestMode::Cors);

    let request =
        Request::new_with_str_and_init(url, &opts).map_err(|err| HttpError::RequestFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;

    let window = web_sys::window().ok_or(HttpError::NoWindow)?;
    let resp_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(|err| HttpError::RequestFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;

    let resp: Response = resp_value
        .dyn_into()
        .map_err(|_| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response is not a Response object".to_string(),
        })?;

    if !resp.ok() {
        return Err(HttpError::HttpStatus {
            url: url.to_string(),
            status: resp.status(),
        });
    }

    Ok(resp)
}

/// Wraps a JavaScript error raised while reading a response body.
#[cfg(target_arch = "wasm32")]
fn web_body_read_error(url: &str, err: wasm_bindgen::JsValue) -> HttpError {
    HttpError::BodyReadFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    }
}

/// Fetches `url` in the browser and returns the response body as a string.
///
/// # Errors
///
/// - Any error from [`fetch_response_web`].
/// - [`HttpError::BodyReadFailed`] if reading the body fails.
/// - [`HttpError::InvalidResponse`] if the body is not a JS string.
#[cfg(target_arch = "wasm32")]
async fn fetch_text_web(url: &str) -> Result<String, HttpError> {
    use wasm_bindgen_futures::JsFuture;

    let resp = fetch_response_web(url).await?;

    let text_promise = resp.text().map_err(|err| web_body_read_error(url, err))?;
    let text_value = JsFuture::from(text_promise)
        .await
        .map_err(|err| web_body_read_error(url, err))?;

    text_value
        .as_string()
        .ok_or_else(|| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response body is not a string".to_string(),
        })
}

/// Fetches `url` in the browser and returns the response body as bytes.
///
/// The body is read as an `ArrayBuffer` and copied out through a `Uint8Array` view.
///
/// # Errors
///
/// - Any error from [`fetch_response_web`].
/// - [`HttpError::BodyReadFailed`] if reading the body fails.
#[cfg(target_arch = "wasm32")]
async fn fetch_bytes_web(url: &str) -> Result<Vec<u8>, HttpError> {
    use wasm_bindgen_futures::JsFuture;

    let resp = fetch_response_web(url).await?;

    let bytes_promise = resp
        .array_buffer()
        .map_err(|err| web_body_read_error(url, err))?;
    let bytes_value = JsFuture::from(bytes_promise)
        .await
        .map_err(|err| web_body_read_error(url, err))?;

    Ok(js_sys::Uint8Array::new(&bytes_value).to_vec())
}
247
248pub fn default_http_client() -> HttpClientRef {
249    Arc::new(DefaultHttpClient)
250}
251
252pub fn local_http_client() -> CompositionLocal<HttpClientRef> {
253    thread_local! {
254        static LOCAL_HTTP_CLIENT: std::cell::RefCell<Option<CompositionLocal<HttpClientRef>>> = const { std::cell::RefCell::new(None) };
255    }
256
257    LOCAL_HTTP_CLIENT.with(|cell| {
258        let mut local = cell.borrow_mut();
259        if local.is_none() {
260            *local = Some(compositionLocalOf(default_http_client));
261        }
262        local
263            .as_ref()
264            .expect("HTTP client composition local must be initialized")
265            .clone()
266    })
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_test_composition;
    use cranpose_core::CompositionLocalProvider;
    use std::cell::RefCell;
    use std::rc::Rc;
    #[cfg(not(target_arch = "wasm32"))]
    use std::thread;

    /// Minimal client that answers every `get_text` with `"ok"`.
    /// It relies on the trait's default `get_bytes`.
    struct TestHttpClient;

    impl HttpClient for TestHttpClient {
        fn get_text<'a>(&'a self, _url: &'a str) -> HttpFuture<'a, String> {
            Box::pin(async { Ok("ok".to_string()) })
        }
    }

    #[test]
    fn default_http_client_is_available() {
        let client = default_http_client();
        let cloned = client.clone();
        assert_eq!(Arc::strong_count(&client), 2);
        drop(cloned);
        assert_eq!(Arc::strong_count(&client), 1);
    }

    #[test]
    fn test_client_uses_default_get_bytes_from_text() {
        let client = TestHttpClient;
        let bytes = pollster::block_on(client.get_bytes("https://example.com")).expect("bytes");
        assert_eq!(bytes, b"ok".to_vec());
    }

    #[test]
    fn local_http_client_can_be_overridden() {
        let local = local_http_client();
        let default_client = default_http_client();
        let custom_client: HttpClientRef = Arc::new(TestHttpClient);
        let captured = Rc::new(RefCell::new(None));

        {
            let captured_for_closure = Rc::clone(&captured);
            let custom_client = custom_client.clone();
            let local_for_provider = local.clone();
            let local_for_read = local.clone();
            run_test_composition(move || {
                let captured = Rc::clone(&captured_for_closure);
                let local_for_read = local_for_read.clone();
                CompositionLocalProvider(
                    vec![local_for_provider.provides(custom_client.clone())],
                    move || {
                        let current = local_for_read.current();
                        *captured.borrow_mut() = Some(current);
                    },
                );
            });
        }

        let current = captured.borrow().as_ref().expect("client captured").clone();
        assert!(Arc::ptr_eq(&current, &custom_client));
        assert!(!Arc::ptr_eq(&current, &default_client));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn native_http_client_builds() {
        build_native_client().expect("native HTTP client should initialize");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_http_client_fetches_text_from_local_server() {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").expect("bind local test server");
        let address = listener
            .local_addr()
            .expect("read local test server address");
        let server = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept local test request");
            // Drain the whole request head before replying. A single `read` may
            // return only part of the request, and closing a socket that still has
            // unread input can reset the connection on some platforms before the
            // client has read the response.
            let mut request: Vec<u8> = Vec::new();
            let mut chunk = [0_u8; 1024];
            while !request.windows(4).any(|window| window == b"\r\n\r\n") {
                let read = stream.read(&mut chunk).expect("read local test request");
                if read == 0 {
                    break;
                }
                request.extend_from_slice(&chunk[..read]);
            }
            let body = "cranpose-http-test";
            write!(
                stream,
                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                body.len(),
                body
            )
            .expect("write local test response");
        });

        let url = format!("http://{address}");
        let text = pollster::block_on(default_http_client().get_text(&url))
            .expect("fetch text from local test server");
        server.join().expect("join local test server");

        assert_eq!(text, "cranpose-http-test");
    }
}