1use cranpose_core::{compositionLocalOf, CompositionLocal};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6#[derive(thiserror::Error, Debug, Clone)]
7pub enum HttpError {
8 #[error("Failed to build HTTP client: {0}")]
9 ClientInit(String),
10 #[error("Request failed for {url}: {message}")]
11 RequestFailed { url: String, message: String },
12 #[error("Request failed with status {status} for {url}")]
13 HttpStatus { url: String, status: u16 },
14 #[error("Failed to read response body for {url}: {message}")]
15 BodyReadFailed { url: String, message: String },
16 #[error("Invalid response for {url}: {message}")]
17 InvalidResponse { url: String, message: String },
18 #[error("No window object available")]
19 NoWindow,
20}
21
22#[cfg(not(target_arch = "wasm32"))]
23pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + Send + 'a>>;
24
25#[cfg(target_arch = "wasm32")]
26pub type HttpFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, HttpError>> + 'a>>;
27
28pub trait HttpClient: Send + Sync {
29 fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String>;
30
31 fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
32 Box::pin(async move { self.get_text(url).await.map(|text| text.into_bytes()) })
33 }
34}
35
36pub type HttpClientRef = Arc<dyn HttpClient>;
37
38struct DefaultHttpClient;
39
40impl HttpClient for DefaultHttpClient {
41 fn get_text<'a>(&'a self, url: &'a str) -> HttpFuture<'a, String> {
42 Box::pin(async move {
43 #[cfg(not(target_arch = "wasm32"))]
44 {
45 fetch_text_native(url)
46 }
47
48 #[cfg(target_arch = "wasm32")]
49 {
50 fetch_text_web(url).await
51 }
52 })
53 }
54
55 fn get_bytes<'a>(&'a self, url: &'a str) -> HttpFuture<'a, Vec<u8>> {
56 Box::pin(async move {
57 #[cfg(not(target_arch = "wasm32"))]
58 {
59 fetch_bytes_native(url)
60 }
61
62 #[cfg(target_arch = "wasm32")]
63 {
64 fetch_bytes_web(url).await
65 }
66 })
67 }
68}
69
70#[cfg(not(target_arch = "wasm32"))]
71fn fetch_text_native(url: &str) -> Result<String, HttpError> {
72 native_response(url)?
73 .text()
74 .map_err(|err| HttpError::BodyReadFailed {
75 url: url.to_string(),
76 message: err.to_string(),
77 })
78}
79
80#[cfg(not(target_arch = "wasm32"))]
81fn fetch_bytes_native(url: &str) -> Result<Vec<u8>, HttpError> {
82 native_response(url)?
83 .bytes()
84 .map(|bytes| bytes.to_vec())
85 .map_err(|err| HttpError::BodyReadFailed {
86 url: url.to_string(),
87 message: err.to_string(),
88 })
89}
90
91#[cfg(not(target_arch = "wasm32"))]
92fn native_response(url: &str) -> Result<reqwest::blocking::Response, HttpError> {
93 let response = native_client()?
94 .get(url)
95 .send()
96 .map_err(|err| HttpError::RequestFailed {
97 url: url.to_string(),
98 message: err.to_string(),
99 })?;
100
101 let status = response.status();
102 if !status.is_success() {
103 return Err(HttpError::HttpStatus {
104 url: url.to_string(),
105 status: status.as_u16(),
106 });
107 }
108
109 Ok(response)
110}
111
112#[cfg(not(target_arch = "wasm32"))]
113fn native_client() -> Result<&'static reqwest::blocking::Client, HttpError> {
114 use std::sync::OnceLock;
115
116 static CLIENT: OnceLock<Result<reqwest::blocking::Client, HttpError>> = OnceLock::new();
117 CLIENT
118 .get_or_init(build_native_client)
119 .as_ref()
120 .map_err(Clone::clone)
121}
122
123#[cfg(not(target_arch = "wasm32"))]
124fn build_native_client() -> Result<reqwest::blocking::Client, HttpError> {
125 use std::time::Duration;
126
127 reqwest::blocking::Client::builder()
128 .timeout(Duration::from_secs(10))
129 .user_agent("cranpose/0.1")
130 .build()
131 .map_err(|err| HttpError::ClientInit(err.to_string()))
132}
133
/// Issues a CORS `GET` for `url` through `window.fetch` and returns the
/// response once its status has been checked.
///
/// Shared by [`fetch_text_web`] and [`fetch_bytes_web`] so both apply the same
/// request options and status validation.
///
/// # Errors
///
/// * [`HttpError::RequestFailed`] when the request cannot be created or the
///   fetch promise rejects.
/// * [`HttpError::NoWindow`] when `web_sys::window()` returns `None`.
/// * [`HttpError::InvalidResponse`] when the resolved value is not a `Response`.
/// * [`HttpError::HttpStatus`] when `Response::ok()` is false.
#[cfg(target_arch = "wasm32")]
async fn web_response(url: &str) -> Result<web_sys::Response, HttpError> {
    use wasm_bindgen::JsCast;
    use wasm_bindgen_futures::JsFuture;
    use web_sys::{Request, RequestInit, RequestMode, Response};

    let opts = RequestInit::new();
    opts.set_method("GET");
    opts.set_mode(RequestMode::Cors);

    let request =
        Request::new_with_str_and_init(url, &opts).map_err(|err| HttpError::RequestFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;

    let window = web_sys::window().ok_or(HttpError::NoWindow)?;
    let resp_value = JsFuture::from(window.fetch_with_request(&request))
        .await
        .map_err(|err| HttpError::RequestFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })?;

    let resp: Response = resp_value
        .dyn_into()
        .map_err(|_| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response is not a Response object".to_string(),
        })?;

    if !resp.ok() {
        return Err(HttpError::HttpStatus {
            url: url.to_string(),
            status: resp.status(),
        });
    }

    Ok(resp)
}

/// Awaits a body-reading promise (`Response::text()`, `Response::array_buffer()`)
/// and returns the resolved JS value.
///
/// `body` is the direct result of the `web_sys` call, so both the synchronous
/// failure to obtain the promise and the promise's rejection are mapped to
/// [`HttpError::BodyReadFailed`].
#[cfg(target_arch = "wasm32")]
async fn web_body(
    url: &str,
    body: Result<js_sys::Promise, wasm_bindgen::JsValue>,
) -> Result<wasm_bindgen::JsValue, HttpError> {
    use wasm_bindgen_futures::JsFuture;

    let promise = body.map_err(|err| HttpError::BodyReadFailed {
        url: url.to_string(),
        message: format!("{:?}", err),
    })?;

    JsFuture::from(promise)
        .await
        .map_err(|err| HttpError::BodyReadFailed {
            url: url.to_string(),
            message: format!("{:?}", err),
        })
}

/// Fetches `url` in the browser and returns the response body as text.
///
/// # Errors
///
/// Any error from the shared request/body helpers, plus
/// [`HttpError::InvalidResponse`] when the resolved body is not a JS string.
#[cfg(target_arch = "wasm32")]
async fn fetch_text_web(url: &str) -> Result<String, HttpError> {
    let resp = web_response(url).await?;
    let text_value = web_body(url, resp.text()).await?;

    text_value
        .as_string()
        .ok_or_else(|| HttpError::InvalidResponse {
            url: url.to_string(),
            message: "Response body is not a string".to_string(),
        })
}

/// Fetches `url` in the browser and returns the response body as bytes.
///
/// The body is read as an `ArrayBuffer` and copied out through a `Uint8Array`
/// view.
///
/// # Errors
///
/// Any error from the shared request/body helpers.
#[cfg(target_arch = "wasm32")]
async fn fetch_bytes_web(url: &str) -> Result<Vec<u8>, HttpError> {
    let resp = web_response(url).await?;
    let bytes_value = web_body(url, resp.array_buffer()).await?;

    let array = js_sys::Uint8Array::new(&bytes_value);
    Ok(array.to_vec())
}
247
248pub fn default_http_client() -> HttpClientRef {
249 Arc::new(DefaultHttpClient)
250}
251
252pub fn local_http_client() -> CompositionLocal<HttpClientRef> {
253 thread_local! {
254 static LOCAL_HTTP_CLIENT: std::cell::RefCell<Option<CompositionLocal<HttpClientRef>>> = const { std::cell::RefCell::new(None) };
255 }
256
257 LOCAL_HTTP_CLIENT.with(|cell| {
258 let mut local = cell.borrow_mut();
259 if local.is_none() {
260 *local = Some(compositionLocalOf(default_http_client));
261 }
262 local
263 .as_ref()
264 .expect("HTTP client composition local must be initialized")
265 .clone()
266 })
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_test_composition;
    use cranpose_core::CompositionLocalProvider;
    use std::cell::RefCell;
    use std::rc::Rc;
    #[cfg(not(target_arch = "wasm32"))]
    use std::thread;

    /// Client that only implements `get_text`, so the trait's default
    /// `get_bytes` is what gets exercised.
    struct TestHttpClient;

    impl HttpClient for TestHttpClient {
        fn get_text<'a>(&'a self, _url: &'a str) -> HttpFuture<'a, String> {
            Box::pin(async { Ok("ok".to_string()) })
        }
    }

    #[test]
    fn default_http_client_is_available() {
        let client = default_http_client();
        let cloned = client.clone();
        assert_eq!(Arc::strong_count(&client), 2);
        drop(cloned);
        assert_eq!(Arc::strong_count(&client), 1);
    }

    #[test]
    fn test_client_uses_default_get_bytes_from_text() {
        let client = TestHttpClient;
        let bytes = pollster::block_on(client.get_bytes("https://example.com")).expect("bytes");
        assert_eq!(bytes, b"ok".to_vec());
    }

    #[test]
    fn local_http_client_can_be_overridden() {
        let local = local_http_client();
        let default_client = default_http_client();
        let custom_client: HttpClientRef = Arc::new(TestHttpClient);
        // Filled in from inside the composition with whatever client the local resolves to.
        let captured = Rc::new(RefCell::new(None));

        {
            let captured_for_closure = Rc::clone(&captured);
            let custom_client = custom_client.clone();
            let local_for_provider = local.clone();
            let local_for_read = local.clone();
            run_test_composition(move || {
                let captured = Rc::clone(&captured_for_closure);
                let local_for_read = local_for_read.clone();
                CompositionLocalProvider(
                    vec![local_for_provider.provides(custom_client.clone())],
                    move || {
                        let current = local_for_read.current();
                        *captured.borrow_mut() = Some(current);
                    },
                );
            });
        }

        let current = captured.borrow().as_ref().expect("client captured").clone();
        assert!(Arc::ptr_eq(&current, &custom_client));
        assert!(!Arc::ptr_eq(&current, &default_client));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn native_http_client_builds() {
        build_native_client().expect("native HTTP client should initialize");
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn default_http_client_fetches_text_from_local_server() {
        use std::io::{Read, Write};
        use std::net::TcpListener;

        // Port 0 lets the OS pick a free port; the real address is read back below.
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind local test server");
        let address = listener
            .local_addr()
            .expect("read local test server address");
        // One-shot server: accept a single connection, read the request, reply once.
        let server = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept local test request");
            let mut request = [0_u8; 1024];
            let _ = stream.read(&mut request).expect("read local test request");
            let body = "cranpose-http-test";
            write!(
                stream,
                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                body.len(),
                body
            )
            .expect("write local test response");
        });

        let url = format!("http://{address}");
        let text = pollster::block_on(default_http_client().get_text(&url))
            .expect("fetch text from local test server");
        server.join().expect("join local test server");

        assert_eq!(text, "cranpose-http-test");
    }
}