Skip to main content

cranpose_services/
http.rs

1use cranpose_core::{compositionLocalOfWithPolicy, CompositionLocal};
2#[cfg(target_arch = "wasm32")]
3use futures_util::{stream, StreamExt};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
/// Errors reported by the HTTP helpers in this module.
///
/// Every variant stores plain strings/integers so the error stays `Clone`.
#[derive(thiserror::Error, Debug, Clone)]
pub enum HttpError {
    /// The HTTP client could not be constructed; the payload is the failure text.
    #[error("Failed to build HTTP client: {0}")]
    ClientInit(String),
    /// The request for `url` could not be created or sent.
    #[error("Request failed for {url}: {message}")]
    RequestFailed { url: String, message: String },
    /// The server answered `url` with a non-success HTTP status code.
    #[error("Request failed with status {status} for {url}")]
    HttpStatus { url: String, status: u16 },
    /// The response for `url` arrived but its body could not be read.
    #[error("Failed to read response body for {url}: {message}")]
    BodyReadFailed { url: String, message: String },
    /// The response for `url` did not have the expected shape.
    #[error("Invalid response for {url}: {message}")]
    InvalidResponse { url: String, message: String },
    /// No browser `window` object was available to issue the request.
    #[error("No window object available")]
    NoWindow,
}
23
/// Boxed future returned by [`HttpClient`] methods on non-wasm targets.
///
/// The future is required to be `Send` here.
#[cfg(not(target_arch = "wasm32"))]
pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + Send + 'a>>;

/// Boxed future returned by [`HttpClient`] methods on `wasm32`.
///
/// Unlike the non-wasm alias, this one carries no `Send` bound.
#[cfg(target_arch = "wasm32")]
pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + 'a>>;
29
30pub trait HttpClient: Send + Sync {
31    fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String>;
32
33    fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
34        Box::pin(async move { self.get_text(url).await.map(|text| text.into_bytes()) })
35    }
36}
37
/// Shared, reference-counted handle to a dynamically dispatched [`HttpClient`].
pub type HttpClientRef = Arc<dyn HttpClient>;
39
40#[cfg(not(target_arch = "wasm32"))]
41pub async fn map_ordered_concurrent<I, T, F, Fut>(
42    items: &[I],
43    concurrency: usize,
44    task: F,
45) -> Vec<T>
46where
47    I: Clone + Send,
48    T: Send,
49    F: Fn(I) -> Fut + Send + Sync + 'static,
50    Fut: Future<Output = T> + Send,
51{
52    let task = Arc::new(task);
53    let mut results = Vec::with_capacity(items.len());
54
55    for chunk in items.chunks(concurrency.max(1)) {
56        std::thread::scope(|scope| {
57            let mut handles = Vec::with_capacity(chunk.len());
58            for item in chunk.iter().cloned() {
59                let task = Arc::clone(&task);
60                handles.push(scope.spawn(move || pollster::block_on(task(item))));
61            }
62
63            for handle in handles {
64                results.push(
65                    handle
66                        .join()
67                        .unwrap_or_else(|_| panic!("ordered concurrent worker thread panicked")),
68                );
69            }
70        });
71    }
72
73    results
74}
75
/// Applies `task` to every element of `items` and returns the outputs in input order.
///
/// The futures are polled through `buffer_unordered` with a limit of `concurrency`
/// (`0` is treated as `1`). Each output is tagged with the index of its input so the
/// original order can be restored after all futures have finished.
#[cfg(target_arch = "wasm32")]
pub async fn map_ordered_concurrent<I, T, F, Fut>(
    items: &[I],
    concurrency: usize,
    task: F,
) -> Vec<T>
where
    I: Clone,
    F: Fn(I) -> Fut + Clone,
    Fut: Future<Output = T>,
{
    let limit = concurrency.max(1);
    let indexed_jobs = items.iter().cloned().enumerate().map(|(position, item)| {
        let job = task.clone();
        async move { (position, job(item).await) }
    });

    let mut completed: Vec<(usize, T)> = stream::iter(indexed_jobs)
        .buffer_unordered(limit)
        .collect()
        .await;

    // Positions are unique, so an unstable sort yields the same order as a stable one.
    completed.sort_unstable_by_key(|entry| entry.0);
    completed.into_iter().map(|(_, output)| output).collect()
}
98
99struct DefaultHttpClient;
100
101impl HttpClient for DefaultHttpClient {
102    fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String> {
103        Box::pin(async move {
104            #[cfg(not(target_arch = "wasm32"))]
105            {
106                fetch_text_native(url)
107            }
108
109            #[cfg(target_arch = "wasm32")]
110            {
111                fetch_text_web(url).await
112            }
113        })
114    }
115
116    fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
117        Box::pin(async move {
118            #[cfg(not(target_arch = "wasm32"))]
119            {
120                fetch_bytes_native(url)
121            }
122
123            #[cfg(target_arch = "wasm32")]
124            {
125                fetch_bytes_web(url).await
126            }
127        })
128    }
129}
130
131#[cfg(not(target_arch = "wasm32"))]
132fn fetch_text_native(url: &str) -> Result<String, HttpError> {
133    native_response(url)?
134        .text()
135        .map_err(|err| HttpError::BodyReadFailed {
136            url: url.to_string(),
137            message: err.to_string(),
138        })
139}
140
141#[cfg(not(target_arch = "wasm32"))]
142fn fetch_bytes_native(url: &str) -> Result<Vec<u8>, HttpError> {
143    native_response(url)?
144        .bytes()
145        .map(|bytes| bytes.to_vec())
146        .map_err(|err| HttpError::BodyReadFailed {
147            url: url.to_string(),
148            message: err.to_string(),
149        })
150}
151
152#[cfg(not(target_arch = "wasm32"))]
153fn native_response(url: &str) -> Result<reqwest::blocking::Response, HttpError> {
154    let response = native_client()?
155        .get(url)
156        .send()
157        .map_err(|err| HttpError::RequestFailed {
158            url: url.to_string(),
159            message: err.to_string(),
160        })?;
161
162    let status = response.status();
163    if !status.is_success() {
164        return Err(HttpError::HttpStatus {
165            url: url.to_string(),
166            status: status.as_u16(),
167        });
168    }
169
170    Ok(response)
171}
172
173#[cfg(not(target_arch = "wasm32"))]
174fn native_client() -> Result<&'static reqwest::blocking::Client, HttpError> {
175    use std::sync::OnceLock;
176
177    static CLIENT: OnceLock<Result<reqwest::blocking::Client, HttpError>> = OnceLock::new();
178    CLIENT
179        .get_or_init(build_native_client)
180        .as_ref()
181        .map_err(Clone::clone)
182}
183
184#[cfg(not(target_arch = "wasm32"))]
185fn build_native_client() -> Result<reqwest::blocking::Client, HttpError> {
186    use std::time::Duration;
187
188    configure_native_client_builder(
189        reqwest::blocking::Client::builder()
190            .timeout(Duration::from_secs(10))
191            .user_agent("cranpose/0.1"),
192    )?
193    .build()
194    .map_err(|err| HttpError::ClientInit(err.to_string()))
195}
196
/// Applies target-OS specific settings to the client builder.
///
/// On Android the builder is restricted to the certificates returned by
/// `android_root_certificates`; on every other OS it is returned unchanged.
///
/// # Errors
///
/// On Android, propagates the error from `android_root_certificates`.
#[cfg(not(target_arch = "wasm32"))]
fn configure_native_client_builder(
    builder: reqwest::blocking::ClientBuilder,
) -> Result<reqwest::blocking::ClientBuilder, HttpError> {
    // Exactly one of the two cfg-gated blocks below survives compilation.
    #[cfg(target_os = "android")]
    {
        return Ok(builder.tls_certs_only(android_root_certificates()?));
    }

    #[cfg(not(target_os = "android"))]
    {
        Ok(builder)
    }
}
211
/// Converts the root certificates bundled in `webpki_root_certs` into
/// `reqwest::Certificate` values for use on Android.
///
/// # Errors
///
/// Propagates the error from `certificates_from_der_chain`.
#[cfg(target_os = "android")]
fn android_root_certificates() -> Result<Vec<reqwest::Certificate>, HttpError> {
    let bundled_roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
    certificates_from_der_chain(bundled_roots.iter().map(|root| root.as_ref()))
}
220
/// Parses each DER-encoded certificate in `certificates`, preserving order.
///
/// # Errors
///
/// Stops at the first entry that fails to parse and returns
/// [`HttpError::ClientInit`] naming the zero-based index of that entry.
#[cfg(any(test, target_os = "android"))]
fn certificates_from_der_chain<'a, I>(
    certificates: I,
) -> Result<Vec<reqwest::Certificate>, HttpError>
where
    I: IntoIterator<Item = &'a [u8]>,
{
    let mut parsed = Vec::new();
    for (index, der) in certificates.into_iter().enumerate() {
        let certificate = match reqwest::Certificate::from_der(der) {
            Ok(certificate) => certificate,
            Err(err) => {
                return Err(HttpError::ClientInit(format!(
                    "Failed to load TLS root certificate {index}: {err}"
                )))
            }
        };
        parsed.push(certificate);
    }
    Ok(parsed)
}
240
/// Issues a CORS `GET` for `url` through the browser `fetch` API and returns the
/// response once its status has been validated.
///
/// Shared by [`fetch_text_web`] and [`fetch_bytes_web`] so both read the body of a
/// response that went through identical request/validation steps.
///
/// # Errors
///
/// * [`HttpError::RequestFailed`] when the request cannot be created or `fetch` rejects.
/// * [`HttpError::NoWindow`] when `web_sys::window()` returns `None`.
/// * [`HttpError::InvalidResponse`] when the resolved value is not a `Response`.
/// * [`HttpError::HttpStatus`] when `Response::ok()` is false.
#[cfg(target_arch = "wasm32")]
async fn fetch_response_web(url: &str) -> Result<web_sys::Response, HttpError> {
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;
    use web_sys::{Request, RequestInit, RequestMode, Response};

    let opts = RequestInit::new();
    opts.set_method("GET");
    opts.set_mode(RequestMode::Cors);

    let request =
        Request::new_with_str_and_init(url, &opts).map_err(|err| HttpError::RequestFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;

    let window = web_sys::window().ok_or(HttpError::NoWindow)?;
    let resp_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(|err| HttpError::RequestFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;

    let resp: Response = resp_value
        .dyn_into()
        .map_err(|_| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response is not a Response object".to_string(),
        })?;

    if !resp.ok() {
        return Err(HttpError::HttpStatus {
            url: url.to_string(),
            status: resp.status(),
        });
    }

    Ok(resp)
}

/// Fetches `url` in the browser and returns the response body as text.
///
/// # Errors
///
/// Propagates every error from `fetch_response_web`, reports a failure to read the
/// body as [`HttpError::BodyReadFailed`], and returns
/// [`HttpError::InvalidResponse`] when the resolved body is not a JS string.
#[cfg(target_arch = "wasm32")]
async fn fetch_text_web(url: &str) -> Result<String, HttpError> {
    use wasm_bindgen_futures::JsFuture;

    let resp = fetch_response_web(url).await?;

    let text_promise = resp.text().map_err(|err| HttpError::BodyReadFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    })?;
    let text_value =
        JsFuture::from(text_promise)
            .await
            .map_err(|err| HttpError::BodyReadFailed {
                url: url.to_string(),
                message: format!("{:?}", err),
            })?;

    text_value
        .as_string()
        .ok_or_else(|| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response body is not a string".to_string(),
        })
}

/// Fetches `url` in the browser and returns the response body as raw bytes.
///
/// # Errors
///
/// Propagates every error from `fetch_response_web` and reports a failure to read
/// the body as [`HttpError::BodyReadFailed`].
#[cfg(target_arch = "wasm32")]
async fn fetch_bytes_web(url: &str) -> Result<Vec<u8>, HttpError> {
    use wasm_bindgen_futures::JsFuture;

    let resp = fetch_response_web(url).await?;

    let bytes_promise = resp
        .array_buffer()
        .map_err(|err| HttpError::BodyReadFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;
    let bytes_value =
        JsFuture::from(bytes_promise)
            .await
            .map_err(|err| HttpError::BodyReadFailed {
                url: url.to_string(),
                message: format!("{:?}", err),
            })?;

    // View the resolved ArrayBuffer as bytes and copy it into Rust-owned memory.
    let array = js_sys::Uint8Array::new(&bytes_value);
    Ok(array.to_vec())
}
354
355pub fn default_http_client() -> HttpClientRef {
356    Arc::new(DefaultHttpClient)
357}
358
359pub fn local_http_client() -> CompositionLocal<HttpClientRef> {
360    thread_local! {
361        static LOCAL_HTTP_CLIENT: std::cell::RefCell<Option<CompositionLocal<HttpClientRef>>> = const { std::cell::RefCell::new(None) };
362    }
363
364    LOCAL_HTTP_CLIENT.with(|cell| {
365        let mut local = cell.borrow_mut();
366        if local.is_none() {
367            *local = Some(compositionLocalOfWithPolicy(
368                default_http_client,
369                Arc::ptr_eq,
370            ));
371        }
372        local
373            .as_ref()
374            .expect("HTTP client composition local must be initialized")
375            .clone()
376    })
377}
378
// Unit tests for the HTTP client abstraction, the ordered-concurrency helper and
// the native client setup.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_test_composition;
    use cranpose_core::CompositionLocalProvider;
    use std::cell::RefCell;
    use std::rc::Rc;
    #[cfg(not(target_arch = "wasm32"))]
    use std::thread;

    /// Stub client that answers every text request with `"ok"` and relies on the
    /// trait's default `get_bytes`.
    struct TestHttpClient;

    impl HttpClient for TestHttpClient {
        fn get_text<'a>(&'a self, _url: &'a str) -> HttpFuture<'a, String> {
            Box::pin(async { Ok("ok".to_string()) })
        }
    }

    /// The default client handle is a plain `Arc` whose count follows clones/drops.
    #[test]
    fn default_http_client_is_available() {
        let client = default_http_client();
        let cloned = client.clone();
        assert_eq!(Arc::strong_count(&client), 2);
        drop(cloned);
        assert_eq!(Arc::strong_count(&client), 1);
    }

    /// The default `get_bytes` returns the bytes of whatever `get_text` produced.
    #[test]
    fn test_client_uses_default_get_bytes_from_text() {
        let client = TestHttpClient;
        let bytes = pollster::block_on(client.get_bytes("https://example.com")).expect("bytes");
        assert_eq!(bytes, b"ok".to_vec());
    }

    /// Outputs come back in input order even with a concurrency limit below the
    /// number of items.
    #[test]
    fn map_ordered_concurrent_preserves_input_order() {
        let inputs = [3usize, 1, 4, 1, 5];
        let outputs = pollster::block_on(map_ordered_concurrent(&inputs, 2, |value| async move {
            value * 10
        }));

        assert_eq!(outputs, vec![30, 10, 40, 10, 50]);
    }

    /// A value provided through `CompositionLocalProvider` replaces the default
    /// client when the local is read inside the provider's content.
    #[test]
    fn local_http_client_can_be_overridden() {
        let local = local_http_client();
        let default_client = default_http_client();
        let custom_client: HttpClientRef = Arc::new(TestHttpClient);
        // Filled in by the composition body with the client it observed.
        let captured = Rc::new(RefCell::new(None));

        {
            let captured_for_closure = Rc::clone(&captured);
            let custom_client = custom_client.clone();
            let local_for_provider = local.clone();
            let local_for_read = local.clone();
            run_test_composition(move || {
                let captured = Rc::clone(&captured_for_closure);
                let local_for_read = local_for_read.clone();
                CompositionLocalProvider(
                    vec![local_for_provider.provides(custom_client.clone())],
                    move || {
                        let current = local_for_read.current();
                        *captured.borrow_mut() = Some(current);
                    },
                );
            });
        }

        let current = captured.borrow().as_ref().expect("client captured").clone();
        assert!(Arc::ptr_eq(&current, &custom_client));
        assert!(!Arc::ptr_eq(&current, &default_client));
    }

    /// Building the native client with the module's configuration succeeds.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn native_http_client_builds() {
        build_native_client().expect("native HTTP client should initialize");
    }

    /// The first three bundled root certificates parse into `reqwest` certificates.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn certificates_from_der_chain_accepts_valid_roots() {
        let certificates = certificates_from_der_chain(
            webpki_root_certs::TLS_SERVER_ROOT_CERTS
                .iter()
                .take(3)
                .map(|certificate| certificate.as_ref()),
        )
        .expect("root certificates should parse");

        assert_eq!(certificates.len(), 3);
    }

    /// End-to-end check: the default client fetches text from a one-shot HTTP
    /// server bound to an ephemeral loopback port.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_http_client_fetches_text_from_local_server() {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        let listener = match TcpListener::bind("127.0.0.1:0") {
            Ok(listener) => listener,
            // Binding may be forbidden in sandboxed environments; skip instead of failing.
            Err(err) if err.kind() == std::io::ErrorKind::PermissionDenied => {
                eprintln!("skipping local HTTP server bind in restricted test environment: {err}");
                return;
            }
            Err(err) => panic!("bind local test server: {err}"),
        };
        let address = listener
            .local_addr()
            .expect("read local test server address");
        // Serve exactly one connection with a fixed body, then let the thread exit.
        let server = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept local test request");
            let mut request = [0_u8; 1024];
            let _ = stream.read(&mut request).expect("read local test request");
            let body = "cranpose-http-test";
            write!(
                stream,
                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                body.len(),
                body
            )
            .expect("write local test response");
        });

        let url = format!("http://{address}");
        let text = pollster::block_on(default_http_client().get_text(&url))
            .expect("fetch text from local test server");
        server.join().expect("join local test server");

        assert_eq!(text, "cranpose-http-test");
    }
}
511}