Skip to main content

cranpose_services/
http.rs

1use cranpose_core::{compositionLocalOf, CompositionLocal};
2#[cfg(target_arch = "wasm32")]
3use futures_util::{stream, StreamExt};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8#[derive(thiserror::Error, Debug, Clone)]
9pub enum HttpError {
10    #[error("Failed to build HTTP client: {0}")]
11    ClientInit(String),
12    #[error("Request failed for {url}: {message}")]
13    RequestFailed { url: String, message: String },
14    #[error("Request failed with status {status} for {url}")]
15    HttpStatus { url: String, status: u16 },
16    #[error("Failed to read response body for {url}: {message}")]
17    BodyReadFailed { url: String, message: String },
18    #[error("Invalid response for {url}: {message}")]
19    InvalidResponse { url: String, message: String },
20    #[error("No window object available")]
21    NoWindow,
22}
23
24#[cfg(not(target_arch = "wasm32"))]
25pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + Send + 'a>>;
26
27#[cfg(target_arch = "wasm32")]
28pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + 'a>>;
29
30pub trait HttpClient: Send + Sync {
31    fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String>;
32
33    fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
34        Box::pin(async move { self.get_text(url).await.map(|text| text.into_bytes()) })
35    }
36}
37
38pub type HttpClientRef = Arc<dyn HttpClient>;
39
40#[cfg(not(target_arch = "wasm32"))]
41pub async fn map_ordered_concurrent<I, T, F, Fut>(
42    items: &[I],
43    concurrency: usize,
44    task: F,
45) -> Vec<T>
46where
47    I: Clone + Send,
48    T: Send,
49    F: Fn(I) -> Fut + Send + Sync + 'static,
50    Fut: Future<Output = T> + Send,
51{
52    let task = Arc::new(task);
53    let mut results = Vec::with_capacity(items.len());
54
55    for chunk in items.chunks(concurrency.max(1)) {
56        std::thread::scope(|scope| {
57            let mut handles = Vec::with_capacity(chunk.len());
58            for item in chunk.iter().cloned() {
59                let task = Arc::clone(&task);
60                handles.push(scope.spawn(move || pollster::block_on(task(item))));
61            }
62
63            for handle in handles {
64                results.push(
65                    handle
66                        .join()
67                        .unwrap_or_else(|_| panic!("ordered concurrent worker thread panicked")),
68                );
69            }
70        });
71    }
72
73    results
74}
75
/// Runs `task` on every element of `items` with at most `concurrency` futures in flight, and
/// returns the outputs in the same order as the inputs.
///
/// `wasm32` implementation: futures are polled concurrently on the current task through
/// `buffer_unordered`. Each output is tagged with its input index and re-sorted at the end.
/// A `concurrency` of `0` is treated as `1`.
#[cfg(target_arch = "wasm32")]
pub async fn map_ordered_concurrent<I, T, F, Fut>(
    items: &[I],
    concurrency: usize,
    task: F,
) -> Vec<T>
where
    I: Clone,
    F: Fn(I) -> Fut + Clone,
    Fut: Future<Output = T>,
{
    let limit = concurrency.max(1);
    let jobs = items.iter().cloned().enumerate().map(|(position, input)| {
        let job = task.clone();
        async move { (position, job(input).await) }
    });

    // Completion order is arbitrary, so restore input order using the attached position.
    let mut tagged: Vec<(usize, T)> = stream::iter(jobs).buffer_unordered(limit).collect().await;
    tagged.sort_by_key(|&(position, _)| position);

    tagged.into_iter().map(|(_, output)| output).collect()
}
98
99struct DefaultHttpClient;
100
101impl HttpClient for DefaultHttpClient {
102    fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String> {
103        Box::pin(async move {
104            #[cfg(not(target_arch = "wasm32"))]
105            {
106                fetch_text_native(url)
107            }
108
109            #[cfg(target_arch = "wasm32")]
110            {
111                fetch_text_web(url).await
112            }
113        })
114    }
115
116    fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
117        Box::pin(async move {
118            #[cfg(not(target_arch = "wasm32"))]
119            {
120                fetch_bytes_native(url)
121            }
122
123            #[cfg(target_arch = "wasm32")]
124            {
125                fetch_bytes_web(url).await
126            }
127        })
128    }
129}
130
131#[cfg(not(target_arch = "wasm32"))]
132fn fetch_text_native(url: &str) -> Result<String, HttpError> {
133    native_response(url)?
134        .text()
135        .map_err(|err| HttpError::BodyReadFailed {
136            url: url.to_string(),
137            message: err.to_string(),
138        })
139}
140
141#[cfg(not(target_arch = "wasm32"))]
142fn fetch_bytes_native(url: &str) -> Result<Vec<u8>, HttpError> {
143    native_response(url)?
144        .bytes()
145        .map(|bytes| bytes.to_vec())
146        .map_err(|err| HttpError::BodyReadFailed {
147            url: url.to_string(),
148            message: err.to_string(),
149        })
150}
151
152#[cfg(not(target_arch = "wasm32"))]
153fn native_response(url: &str) -> Result<reqwest::blocking::Response, HttpError> {
154    let response = native_client()?
155        .get(url)
156        .send()
157        .map_err(|err| HttpError::RequestFailed {
158            url: url.to_string(),
159            message: err.to_string(),
160        })?;
161
162    let status = response.status();
163    if !status.is_success() {
164        return Err(HttpError::HttpStatus {
165            url: url.to_string(),
166            status: status.as_u16(),
167        });
168    }
169
170    Ok(response)
171}
172
173#[cfg(not(target_arch = "wasm32"))]
174fn native_client() -> Result<&'static reqwest::blocking::Client, HttpError> {
175    use std::sync::OnceLock;
176
177    static CLIENT: OnceLock<Result<reqwest::blocking::Client, HttpError>> = OnceLock::new();
178    CLIENT
179        .get_or_init(build_native_client)
180        .as_ref()
181        .map_err(Clone::clone)
182}
183
184#[cfg(not(target_arch = "wasm32"))]
185fn build_native_client() -> Result<reqwest::blocking::Client, HttpError> {
186    use std::time::Duration;
187
188    configure_native_client_builder(
189        reqwest::blocking::Client::builder()
190            .timeout(Duration::from_secs(10))
191            .user_agent("cranpose/0.1"),
192    )?
193    .build()
194    .map_err(|err| HttpError::ClientInit(err.to_string()))
195}
196
197#[cfg(not(target_arch = "wasm32"))]
198fn configure_native_client_builder(
199    builder: reqwest::blocking::ClientBuilder,
200) -> Result<reqwest::blocking::ClientBuilder, HttpError> {
201    #[cfg(target_os = "android")]
202    {
203        return Ok(builder.tls_certs_only(android_root_certificates()?));
204    }
205
206    #[cfg(not(target_os = "android"))]
207    {
208        Ok(builder)
209    }
210}
211
/// Converts every root certificate bundled by `webpki_root_certs` into a
/// `reqwest::Certificate`, for use as the Android trust roots.
///
/// # Errors
///
/// Returns [`HttpError::ClientInit`] for the first root that fails to load.
#[cfg(target_os = "android")]
fn android_root_certificates() -> Result<Vec<reqwest::Certificate>, HttpError> {
    certificates_from_der_chain(
        webpki_root_certs::TLS_SERVER_ROOT_CERTS
            .iter()
            .map(|root| root.as_ref()),
    )
}
220
/// Parses each DER-encoded certificate into a `reqwest::Certificate`, preserving input order.
///
/// # Errors
///
/// Stops at the first certificate that fails to parse. Returns [`HttpError::ClientInit`]
/// naming that certificate's zero-based position in the input.
#[cfg(any(test, target_os = "android"))]
fn certificates_from_der_chain<'a, I>(
    certificates: I,
) -> Result<Vec<reqwest::Certificate>, HttpError>
where
    I: IntoIterator<Item = &'a [u8]>,
{
    let mut parsed = Vec::new();
    for (index, der) in certificates.into_iter().enumerate() {
        let certificate = reqwest::Certificate::from_der(der).map_err(|err| {
            HttpError::ClientInit(format!(
                "Failed to load TLS root certificate {index}: {err}"
            ))
        })?;
        parsed.push(certificate);
    }
    Ok(parsed)
}
240
/// Issues a CORS `GET` for `url` through the browser `fetch` API, and returns the response
/// once its status is known to be successful.
///
/// # Errors
///
/// - [`HttpError::RequestFailed`] if the `Request` cannot be built or the `fetch` promise rejects.
/// - [`HttpError::NoWindow`] when `web_sys::window()` returns `None`.
/// - [`HttpError::InvalidResponse`] if `fetch` resolves to something that is not a `Response`.
/// - [`HttpError::HttpStatus`] when `Response::ok()` is false.
#[cfg(target_arch = "wasm32")]
async fn web_response(url: &str) -> Result<web_sys::Response, HttpError> {
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;
    use web_sys::{Request, RequestInit, RequestMode, Response};

    // Captures only `url` by shared reference, so the closure is `Copy` and can be reused.
    let request_failed = |err: wasm_bindgen::JsValue| HttpError::RequestFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    };

    let opts = RequestInit::new();
    opts.set_method("GET");
    opts.set_mode(RequestMode::Cors);

    let request = Request::new_with_str_and_init(url, &opts).map_err(request_failed)?;

    let window = web_sys::window().ok_or(HttpError::NoWindow)?;
    let resp_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(request_failed)?;

    let resp: Response = resp_value
        .dyn_into()
        .map_err(|_| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response is not a Response object".to_string(),
        })?;

    if !resp.ok() {
        return Err(HttpError::HttpStatus {
            url: url.to_string(),
            status: resp.status(),
        });
    }

    Ok(resp)
}

/// Awaits a body-reading promise such as `Response::text()` or `Response::array_buffer()`.
///
/// # Errors
///
/// Returns [`HttpError::BodyReadFailed`] in either case:
/// - the method fails before returning a promise;
/// - the promise later rejects.
#[cfg(target_arch = "wasm32")]
async fn web_body(
    url: &str,
    promise: Result<js_sys::Promise, wasm_bindgen::JsValue>,
) -> Result<wasm_bindgen::JsValue, HttpError> {
    let body_failed = |err: wasm_bindgen::JsValue| HttpError::BodyReadFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    };

    let promise = promise.map_err(body_failed)?;
    wasm_bindgen_futures::JsFuture::from(promise)
        .await
        .map_err(body_failed)
}

/// Fetches `url` in the browser and returns the body as text.
///
/// # Errors
///
/// - Propagates errors from `web_response` and `web_body`.
/// - [`HttpError::InvalidResponse`] if the resolved body is not a JavaScript string.
#[cfg(target_arch = "wasm32")]
async fn fetch_text_web(url: &str) -> Result<String, HttpError> {
    let resp = web_response(url).await?;
    let text_value = web_body(url, resp.text()).await?;

    text_value
        .as_string()
        .ok_or_else(|| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response body is not a string".to_string(),
        })
}

/// Fetches `url` in the browser and returns the body bytes, read via `arrayBuffer()`.
///
/// # Errors
///
/// Propagates errors from `web_response` and `web_body`.
#[cfg(target_arch = "wasm32")]
async fn fetch_bytes_web(url: &str) -> Result<Vec<u8>, HttpError> {
    let resp = web_response(url).await?;
    let bytes_value = web_body(url, resp.array_buffer()).await?;

    let array = js_sys::Uint8Array::new(&bytes_value);
    Ok(array.to_vec())
}
354
355pub fn default_http_client() -> HttpClientRef {
356    Arc::new(DefaultHttpClient)
357}
358
359pub fn local_http_client() -> CompositionLocal<HttpClientRef> {
360    thread_local! {
361        static LOCAL_HTTP_CLIENT: std::cell::RefCell<Option<CompositionLocal<HttpClientRef>>> = const { std::cell::RefCell::new(None) };
362    }
363
364    LOCAL_HTTP_CLIENT.with(|cell| {
365        let mut local = cell.borrow_mut();
366        if local.is_none() {
367            *local = Some(compositionLocalOf(default_http_client));
368        }
369        local
370            .as_ref()
371            .expect("HTTP client composition local must be initialized")
372            .clone()
373    })
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_test_composition;
    use cranpose_core::CompositionLocalProvider;
    use std::cell::RefCell;
    use std::rc::Rc;
    #[cfg(not(target_arch = "wasm32"))]
    use std::thread;

    /// Stub client whose `get_text` always resolves to `"ok"`.
    ///
    /// `get_bytes` is left to the trait's default implementation.
    struct TestHttpClient;

    impl HttpClient for TestHttpClient {
        fn get_text<'a>(&'a self, _url: &'a str) -> HttpFuture<'a, String> {
            Box::pin(async { Ok("ok".to_string()) })
        }
    }

    /// Builds a raw `Connection: close` HTTP/1.1 response with the given status line and body.
    #[cfg(not(target_arch = "wasm32"))]
    fn raw_http_response(status_line: &str, body: &[u8]) -> Vec<u8> {
        let mut response = format!(
            "HTTP/1.1 {status_line}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
            body.len()
        )
        .into_bytes();
        response.extend_from_slice(body);
        response
    }

    /// Accepts exactly one connection on an ephemeral localhost port and answers it with
    /// `response`.
    ///
    /// Returns the URL to request and the server thread's handle.
    #[cfg(not(target_arch = "wasm32"))]
    fn serve_once(response: Vec<u8>) -> (String, thread::JoinHandle<()>) {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").expect("bind local test server");
        let address = listener
            .local_addr()
            .expect("read local test server address");
        let server = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept local test request");
            // A single read is enough for the short GET requests these tests send.
            let mut request = [0_u8; 1024];
            let _ = stream.read(&mut request).expect("read local test request");
            stream
                .write_all(&response)
                .expect("write local test response");
        });
        (format!("http://{address}"), server)
    }

    #[test]
    fn default_http_client_is_available() {
        let client = default_http_client();
        let cloned = client.clone();
        assert_eq!(Arc::strong_count(&client), 2);
        drop(cloned);
        assert_eq!(Arc::strong_count(&client), 1);
    }

    #[test]
    fn test_client_uses_default_get_bytes_from_text() {
        let client = TestHttpClient;
        let bytes = pollster::block_on(client.get_bytes("https://example.com")).expect("bytes");
        assert_eq!(bytes, b"ok".to_vec());
    }

    #[test]
    fn map_ordered_concurrent_preserves_input_order() {
        let inputs = [3usize, 1, 4, 1, 5];
        let outputs = pollster::block_on(map_ordered_concurrent(&inputs, 2, |value| async move {
            value * 10
        }));

        assert_eq!(outputs, vec![30, 10, 40, 10, 50]);
    }

    #[test]
    fn map_ordered_concurrent_handles_empty_input_and_zero_concurrency() {
        let empty: [usize; 0] = [];
        let outputs =
            pollster::block_on(map_ordered_concurrent(&empty, 4, |value| async move { value }));
        assert!(outputs.is_empty());

        let inputs = [1usize, 2, 3];
        let outputs = pollster::block_on(map_ordered_concurrent(&inputs, 0, |value| async move {
            value + 1
        }));
        assert_eq!(outputs, vec![2, 3, 4]);
    }

    #[test]
    fn local_http_client_can_be_overridden() {
        let local = local_http_client();
        let default_client = default_http_client();
        let custom_client: HttpClientRef = Arc::new(TestHttpClient);
        let captured = Rc::new(RefCell::new(None));

        {
            let captured_for_closure = Rc::clone(&captured);
            let custom_client = custom_client.clone();
            let local_for_provider = local.clone();
            let local_for_read = local.clone();
            run_test_composition(move || {
                let captured = Rc::clone(&captured_for_closure);
                let local_for_read = local_for_read.clone();
                CompositionLocalProvider(
                    vec![local_for_provider.provides(custom_client.clone())],
                    move || {
                        let current = local_for_read.current();
                        *captured.borrow_mut() = Some(current);
                    },
                );
            });
        }

        let current = captured.borrow().as_ref().expect("client captured").clone();
        assert!(Arc::ptr_eq(&current, &custom_client));
        assert!(!Arc::ptr_eq(&current, &default_client));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn native_http_client_builds() {
        build_native_client().expect("native HTTP client should initialize");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn certificates_from_der_chain_accepts_valid_roots() {
        let certificates = certificates_from_der_chain(
            webpki_root_certs::TLS_SERVER_ROOT_CERTS
                .iter()
                .take(3)
                .map(|certificate| certificate.as_ref()),
        )
        .expect("root certificates should parse");

        assert_eq!(certificates.len(), 3);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_http_client_fetches_text_from_local_server() {
        let (url, server) = serve_once(raw_http_response("200 OK", b"cranpose-http-test"));
        let text = pollster::block_on(default_http_client().get_text(&url))
            .expect("fetch text from local test server");
        server.join().expect("join local test server");

        assert_eq!(text, "cranpose-http-test");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_http_client_fetches_raw_bytes_from_local_server() {
        // Deliberately not valid UTF-8: the default client must not route bytes through text.
        let body = [0xff_u8, 0x00, 0xfe, 0x7f];
        let (url, server) = serve_once(raw_http_response("200 OK", &body));
        let bytes = pollster::block_on(default_http_client().get_bytes(&url))
            .expect("fetch bytes from local test server");
        server.join().expect("join local test server");

        assert_eq!(bytes, body.to_vec());
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_http_client_reports_non_success_status() {
        let (url, server) = serve_once(raw_http_response("404 Not Found", b""));
        let error = pollster::block_on(default_http_client().get_text(&url))
            .expect_err("404 response should be reported as an error");
        server.join().expect("join local test server");

        assert!(
            matches!(&error, HttpError::HttpStatus { status: 404, url: failed } if *failed == url),
            "unexpected error: {error}"
        );
    }
}