1use cranpose_core::{compositionLocalOf, CompositionLocal};
2#[cfg(target_arch = "wasm32")]
3use futures_util::{stream, StreamExt};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7
8#[derive(thiserror::Error, Debug, Clone)]
9pub enum HttpError {
10 #[error("Failed to build HTTP client: {0}")]
11 ClientInit(String),
12 #[error("Request failed for {url}: {message}")]
13 RequestFailed { url: String, message: String },
14 #[error("Request failed with status {status} for {url}")]
15 HttpStatus { url: String, status: u16 },
16 #[error("Failed to read response body for {url}: {message}")]
17 BodyReadFailed { url: String, message: String },
18 #[error("Invalid response for {url}: {message}")]
19 InvalidResponse { url: String, message: String },
20 #[error("No window object available")]
21 NoWindow,
22}
23
24#[cfg(not(target_arch = "wasm32"))]
25pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + Send + 'a>>;
26
27#[cfg(target_arch = "wasm32")]
28pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + 'a>>;
29
30pub trait HttpClient: Send + Sync {
31 fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String>;
32
33 fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
34 Box::pin(async move { self.get_text(url).await.map(|text| text.into_bytes()) })
35 }
36}
37
38pub type HttpClientRef = Arc<dyn HttpClient>;
39
40#[cfg(not(target_arch = "wasm32"))]
41pub async fn map_ordered_concurrent<I, T, F, Fut>(
42 items: &[I],
43 concurrency: usize,
44 task: F,
45) -> Vec<T>
46where
47 I: Clone + Send,
48 T: Send,
49 F: Fn(I) -> Fut + Send + Sync + 'static,
50 Fut: Future<Output = T> + Send,
51{
52 let task = Arc::new(task);
53 let mut results = Vec::with_capacity(items.len());
54
55 for chunk in items.chunks(concurrency.max(1)) {
56 std::thread::scope(|scope| {
57 let mut handles = Vec::with_capacity(chunk.len());
58 for item in chunk.iter().cloned() {
59 let task = Arc::clone(&task);
60 handles.push(scope.spawn(move || pollster::block_on(task(item))));
61 }
62
63 for handle in handles {
64 results.push(
65 handle
66 .join()
67 .unwrap_or_else(|_| panic!("ordered concurrent worker thread panicked")),
68 );
69 }
70 });
71 }
72
73 results
74}
75
/// Runs `task` over every element of `items`, with at most `concurrency`
/// futures in flight at once, and returns the outputs in input order.
///
/// On `wasm32` all futures are driven on the current task via
/// `buffer_unordered`, with no threads involved. Each output is tagged with its
/// input position, and the results are re-ordered once everything has
/// completed. A `concurrency` of `0` is treated as `1`.
#[cfg(target_arch = "wasm32")]
pub async fn map_ordered_concurrent<I, T, F, Fut>(
    items: &[I],
    concurrency: usize,
    task: F,
) -> Vec<T>
where
    I: Clone,
    F: Fn(I) -> Fut + Clone,
    Fut: Future<Output = T>,
{
    let jobs = items.iter().cloned().enumerate().map(|(position, item)| {
        let run = task.clone();
        async move { (position, run(item).await) }
    });

    let mut tagged: Vec<(usize, T)> = stream::iter(jobs)
        .buffer_unordered(concurrency.max(1))
        .collect()
        .await;

    // Positions are unique, so an unstable sort restores the exact input order.
    tagged.sort_unstable_by_key(|&(position, _)| position);
    tagged.into_iter().map(|(_, value)| value).collect()
}
98
99struct DefaultHttpClient;
100
101impl HttpClient for DefaultHttpClient {
102 fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String> {
103 Box::pin(async move {
104 #[cfg(not(target_arch = "wasm32"))]
105 {
106 fetch_text_native(url)
107 }
108
109 #[cfg(target_arch = "wasm32")]
110 {
111 fetch_text_web(url).await
112 }
113 })
114 }
115
116 fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
117 Box::pin(async move {
118 #[cfg(not(target_arch = "wasm32"))]
119 {
120 fetch_bytes_native(url)
121 }
122
123 #[cfg(target_arch = "wasm32")]
124 {
125 fetch_bytes_web(url).await
126 }
127 })
128 }
129}
130
131#[cfg(not(target_arch = "wasm32"))]
132fn fetch_text_native(url: &str) -> Result<String, HttpError> {
133 native_response(url)?
134 .text()
135 .map_err(|err| HttpError::BodyReadFailed {
136 url: url.to_string(),
137 message: err.to_string(),
138 })
139}
140
141#[cfg(not(target_arch = "wasm32"))]
142fn fetch_bytes_native(url: &str) -> Result<Vec<u8>, HttpError> {
143 native_response(url)?
144 .bytes()
145 .map(|bytes| bytes.to_vec())
146 .map_err(|err| HttpError::BodyReadFailed {
147 url: url.to_string(),
148 message: err.to_string(),
149 })
150}
151
152#[cfg(not(target_arch = "wasm32"))]
153fn native_response(url: &str) -> Result<reqwest::blocking::Response, HttpError> {
154 let response = native_client()?
155 .get(url)
156 .send()
157 .map_err(|err| HttpError::RequestFailed {
158 url: url.to_string(),
159 message: err.to_string(),
160 })?;
161
162 let status = response.status();
163 if !status.is_success() {
164 return Err(HttpError::HttpStatus {
165 url: url.to_string(),
166 status: status.as_u16(),
167 });
168 }
169
170 Ok(response)
171}
172
173#[cfg(not(target_arch = "wasm32"))]
174fn native_client() -> Result<&'static reqwest::blocking::Client, HttpError> {
175 use std::sync::OnceLock;
176
177 static CLIENT: OnceLock<Result<reqwest::blocking::Client, HttpError>> = OnceLock::new();
178 CLIENT
179 .get_or_init(build_native_client)
180 .as_ref()
181 .map_err(Clone::clone)
182}
183
184#[cfg(not(target_arch = "wasm32"))]
185fn build_native_client() -> Result<reqwest::blocking::Client, HttpError> {
186 use std::time::Duration;
187
188 configure_native_client_builder(
189 reqwest::blocking::Client::builder()
190 .timeout(Duration::from_secs(10))
191 .user_agent("cranpose/0.1"),
192 )?
193 .build()
194 .map_err(|err| HttpError::ClientInit(err.to_string()))
195}
196
197#[cfg(not(target_arch = "wasm32"))]
198fn configure_native_client_builder(
199 builder: reqwest::blocking::ClientBuilder,
200) -> Result<reqwest::blocking::ClientBuilder, HttpError> {
201 #[cfg(target_os = "android")]
202 {
203 return Ok(builder.tls_certs_only(android_root_certificates()?));
204 }
205
206 #[cfg(not(target_os = "android"))]
207 {
208 Ok(builder)
209 }
210}
211
/// Parses the `webpki_root_certs::TLS_SERVER_ROOT_CERTS` bundle into
/// `reqwest` certificates for use on Android.
#[cfg(target_os = "android")]
fn android_root_certificates() -> Result<Vec<reqwest::Certificate>, HttpError> {
    let bundled_roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
    certificates_from_der_chain(bundled_roots.iter().map(|der| der.as_ref()))
}
220
/// Parses each DER-encoded certificate in `certificates` into a
/// `reqwest::Certificate`, preserving order.
///
/// # Errors
///
/// Stops at the first certificate that fails to parse. Returns
/// [`HttpError::ClientInit`] naming that certificate's zero-based position.
#[cfg(any(test, target_os = "android"))]
fn certificates_from_der_chain<'a, I>(
    certificates: I,
) -> Result<Vec<reqwest::Certificate>, HttpError>
where
    I: IntoIterator<Item = &'a [u8]>,
{
    let mut parsed = Vec::new();
    for (index, der) in certificates.into_iter().enumerate() {
        match reqwest::Certificate::from_der(der) {
            Ok(certificate) => parsed.push(certificate),
            Err(err) => {
                return Err(HttpError::ClientInit(format!(
                    "Failed to load TLS root certificate {index}: {err}"
                )))
            }
        }
    }
    Ok(parsed)
}
240
/// Performs a CORS `GET` request for `url` with the browser `fetch` API.
///
/// The response is returned only when `Response.ok` is true. This helper is
/// shared by [`fetch_text_web`] and [`fetch_bytes_web`], so both build the
/// request and map errors identically, mirroring `native_response` on native
/// targets.
///
/// # Errors
///
/// - [`HttpError::RequestFailed`] if the request cannot be built or `fetch` rejects.
/// - [`HttpError::NoWindow`] if `web_sys::window()` returns `None`
///   (presumably a worker or non-browser host).
/// - [`HttpError::InvalidResponse`] if `fetch` resolves to something other than a `Response`.
/// - [`HttpError::HttpStatus`] if the response is not `ok`.
#[cfg(target_arch = "wasm32")]
async fn web_response(url: &str) -> Result<web_sys::Response, HttpError> {
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;
    use web_sys::{Request, RequestInit, RequestMode, Response};

    let opts = RequestInit::new();
    opts.set_method("GET");
    opts.set_mode(RequestMode::Cors);

    // Captures only `&url`, so the closure is `Copy` and can be reused below.
    let request_failed = |err: wasm_bindgen::JsValue| HttpError::RequestFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    };

    let request = Request::new_with_str_and_init(url, &opts).map_err(request_failed)?;

    let window = web_sys::window().ok_or(HttpError::NoWindow)?;
    let resp_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(request_failed)?;

    let resp: Response = resp_value
        .dyn_into()
        .map_err(|_| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response is not a Response object".to_string(),
        })?;

    if !resp.ok() {
        return Err(HttpError::HttpStatus {
            url: url.to_string(),
            status: resp.status(),
        });
    }

    Ok(resp)
}

/// Awaits a body-reading promise (e.g. from `Response::text` or
/// `Response::array_buffer`) and returns the resolved JS value.
///
/// # Errors
///
/// Both a failure to start reading and a rejected promise become
/// [`HttpError::BodyReadFailed`].
#[cfg(target_arch = "wasm32")]
async fn read_web_body(
    url: &str,
    body: Result<js_sys::Promise, wasm_bindgen::JsValue>,
) -> Result<wasm_bindgen::JsValue, HttpError> {
    let body_read_failed = |err: wasm_bindgen::JsValue| HttpError::BodyReadFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    };

    let promise = body.map_err(body_read_failed)?;
    wasm_bindgen_futures::JsFuture::from(promise)
        .await
        .map_err(body_read_failed)
}

/// Downloads `url` in the browser and returns the body as text.
///
/// # Errors
///
/// Propagates errors from `web_response` and `read_web_body`. Returns
/// [`HttpError::InvalidResponse`] if the resolved body is not a JS string.
#[cfg(target_arch = "wasm32")]
async fn fetch_text_web(url: &str) -> Result<String, HttpError> {
    let resp = web_response(url).await?;
    let text_value = read_web_body(url, resp.text()).await?;

    text_value
        .as_string()
        .ok_or_else(|| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response body is not a string".to_string(),
        })
}

/// Downloads `url` in the browser and returns the raw body bytes.
///
/// # Errors
///
/// Propagates errors from `web_response` and `read_web_body`.
#[cfg(target_arch = "wasm32")]
async fn fetch_bytes_web(url: &str) -> Result<Vec<u8>, HttpError> {
    let resp = web_response(url).await?;
    let bytes_value = read_web_body(url, resp.array_buffer()).await?;

    let array = js_sys::Uint8Array::new(&bytes_value);
    Ok(array.to_vec())
}
354
355pub fn default_http_client() -> HttpClientRef {
356 Arc::new(DefaultHttpClient)
357}
358
359pub fn local_http_client() -> CompositionLocal<HttpClientRef> {
360 thread_local! {
361 static LOCAL_HTTP_CLIENT: std::cell::RefCell<Option<CompositionLocal<HttpClientRef>>> = const { std::cell::RefCell::new(None) };
362 }
363
364 LOCAL_HTTP_CLIENT.with(|cell| {
365 let mut local = cell.borrow_mut();
366 if local.is_none() {
367 *local = Some(compositionLocalOf(default_http_client));
368 }
369 local
370 .as_ref()
371 .expect("HTTP client composition local must be initialized")
372 .clone()
373 })
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_test_composition;
    use cranpose_core::CompositionLocalProvider;
    use std::cell::RefCell;
    use std::rc::Rc;
    #[cfg(not(target_arch = "wasm32"))]
    use std::thread;

    /// Stub client whose `get_text` always resolves to `"ok"`.
    struct TestHttpClient;

    impl HttpClient for TestHttpClient {
        fn get_text<'a>(&'a self, _url: &'a str) -> HttpFuture<'a, String> {
            Box::pin(async { Ok("ok".to_string()) })
        }
    }

    #[test]
    fn default_http_client_is_available() {
        let client = default_http_client();
        let cloned = client.clone();
        assert_eq!(Arc::strong_count(&client), 2);
        drop(cloned);
        assert_eq!(Arc::strong_count(&client), 1);
    }

    #[test]
    fn test_client_uses_default_get_bytes_from_text() {
        let client = TestHttpClient;
        let bytes = pollster::block_on(client.get_bytes("https://example.com")).expect("bytes");
        assert_eq!(bytes, b"ok".to_vec());
    }

    #[test]
    fn map_ordered_concurrent_preserves_input_order() {
        let inputs = [3usize, 1, 4, 1, 5];
        let outputs = pollster::block_on(map_ordered_concurrent(&inputs, 2, |value| async move {
            value * 10
        }));

        assert_eq!(outputs, vec![30, 10, 40, 10, 50]);
    }

    #[test]
    fn map_ordered_concurrent_handles_empty_input_and_zero_concurrency() {
        let empty: [usize; 0] = [];
        let outputs = pollster::block_on(map_ordered_concurrent(&empty, 3, |value| async move {
            value + 1
        }));
        assert!(outputs.is_empty());

        let inputs = [1usize, 2, 3];
        let outputs = pollster::block_on(map_ordered_concurrent(&inputs, 0, |value| async move {
            value + 1
        }));
        assert_eq!(outputs, vec![2, 3, 4]);
    }

    #[test]
    fn local_http_client_can_be_overridden() {
        let local = local_http_client();
        let default_client = default_http_client();
        let custom_client: HttpClientRef = Arc::new(TestHttpClient);
        let captured = Rc::new(RefCell::new(None));

        {
            let captured_for_closure = Rc::clone(&captured);
            let custom_client = custom_client.clone();
            let local_for_provider = local.clone();
            let local_for_read = local.clone();
            run_test_composition(move || {
                let captured = Rc::clone(&captured_for_closure);
                let local_for_read = local_for_read.clone();
                CompositionLocalProvider(
                    vec![local_for_provider.provides(custom_client.clone())],
                    move || {
                        let current = local_for_read.current();
                        *captured.borrow_mut() = Some(current);
                    },
                );
            });
        }

        let current = captured.borrow().as_ref().expect("client captured").clone();
        assert!(Arc::ptr_eq(&current, &custom_client));
        assert!(!Arc::ptr_eq(&current, &default_client));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn native_http_client_builds() {
        build_native_client().expect("native HTTP client should initialize");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn certificates_from_der_chain_accepts_valid_roots() {
        let certificates = certificates_from_der_chain(
            webpki_root_certs::TLS_SERVER_ROOT_CERTS
                .iter()
                .take(3)
                .map(|certificate| certificate.as_ref()),
        )
        .expect("root certificates should parse");

        assert_eq!(certificates.len(), 3);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_http_client_fetches_text_from_local_server() {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").expect("bind local test server");
        let address = listener
            .local_addr()
            .expect("read local test server address");
        let server = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept local test request");
            let mut request = [0_u8; 1024];
            let _ = stream.read(&mut request).expect("read local test request");
            let body = "cranpose-http-test";
            write!(
                stream,
                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                body.len(),
                body
            )
            .expect("write local test response");
        });

        let url = format!("http://{address}");
        let text = pollster::block_on(default_http_client().get_text(&url))
            .expect("fetch text from local test server");
        server.join().expect("join local test server");

        assert_eq!(text, "cranpose-http-test");
    }
}