use {
crate::{
error::{Error, Result},
order_book::{DEFAULT_HTTP_TIMEOUT, MAX_RESPONSE_BYTES},
transport::{HttpMethod, HttpRequest, HttpResponse, HttpTransport},
},
js_sys::{Function, Object, Promise, Reflect, global},
wasm_bindgen::{JsCast, JsValue, closure::Closure},
wasm_bindgen_futures::JsFuture,
};
/// Per-request timeout handed to `setTimeout`, in milliseconds.
///
/// Derived from `DEFAULT_HTTP_TIMEOUT` at full millisecond precision, so a
/// fractional-second (or sub-second) timeout is not truncated to whole seconds.
/// Clamped to `i32::MAX`, the largest delay `setTimeout` accepts; larger values
/// overflow in the host and make the timer fire almost immediately.
const FETCH_TIMEOUT_MS: u32 = {
    let millis = DEFAULT_HTTP_TIMEOUT.as_millis();
    if millis > i32::MAX as u128 {
        i32::MAX as u32
    } else {
        millis as u32
    }
};
/// Stateless HTTP transport that delegates every request to the JavaScript
/// global `fetch` function.
#[derive(Clone, Copy, Debug, Default)]
pub struct FetchTransport;
impl HttpTransport for FetchTransport {
    /// Sends `request` through the global `fetch` and wraps the outcome in an
    /// [`HttpResponse`].
    ///
    /// The JSON body bytes, if any, are decoded as UTF-8 (invalid sequences are
    /// replaced with U+FFFD) and passed to `fetch` as a string body. Transport
    /// failures are converted to the crate [`Error`] via `FetchError::into_cow`.
    async fn execute(&self, request: HttpRequest) -> Result<HttpResponse> {
        let verb = match request.method {
            HttpMethod::Get => "GET",
            HttpMethod::Post => "POST",
            HttpMethod::Put => "PUT",
            HttpMethod::Delete => "DELETE",
        };
        let payload: Option<String> = request
            .json_body
            .map(|raw| String::from_utf8_lossy(&raw).into_owned());
        let target = request.url.as_str();
        let token = request.bearer.as_deref();
        match fetch(verb, target, payload.as_deref(), token).await {
            Ok((status, body)) => Ok(HttpResponse { status, body }),
            Err(failure) => Err(failure.into_cow()),
        }
    }
}
enum FetchError {
Js(JsValue),
TooLarge,
}
impl From<JsValue> for FetchError {
fn from(value: JsValue) -> Self {
Self::Js(value)
}
}
impl FetchError {
fn into_cow(self) -> Error {
match self {
Self::TooLarge => Error::ResponseTooLarge {
max: MAX_RESPONSE_BYTES,
},
Self::Js(value) => {
Error::TransportFailed(value.as_string().unwrap_or_else(|| format!("{value:?}")))
}
}
}
}
/// Reads property `name` from `target` and returns it as a callable JS `Function`.
///
/// # Errors
/// Propagates any exception thrown by the property read. If the property is not
/// a function (including when it is `undefined`), returns the JS string
/// `"{name} is not a function"`.
fn get_fn(target: &JsValue, name: &str) -> std::result::Result<Function, JsValue> {
    let key = JsValue::from_str(name);
    let property = Reflect::get(target, &key)?;
    match property.dyn_into::<Function>() {
        Ok(function) => Ok(function),
        Err(_) => Err(JsValue::from_str(&format!("{name} is not a function"))),
    }
}
async fn fetch(
method: &str,
url: &str,
body: Option<&str>,
bearer: Option<&str>,
) -> std::result::Result<(u16, String), FetchError> {
let init = Object::new();
Reflect::set(
&init,
&JsValue::from_str("method"),
&JsValue::from_str(method),
)?;
if body.is_some() || bearer.is_some() {
let headers = Object::new();
if body.is_some() {
Reflect::set(
&headers,
&JsValue::from_str("content-type"),
&JsValue::from_str("application/json"),
)?;
}
if let Some(token) = bearer {
Reflect::set(
&headers,
&JsValue::from_str("authorization"),
&JsValue::from_str(&format!("Bearer {token}")),
)?;
}
Reflect::set(&init, &JsValue::from_str("headers"), &headers)?;
}
if let Some(body) = body {
Reflect::set(&init, &JsValue::from_str("body"), &JsValue::from_str(body))?;
}
let global = global();
let abort_guard = AbortGuard::install(&global, &init, FETCH_TIMEOUT_MS)?;
let fetch = get_fn(&global, "fetch")?;
let promise: Promise = fetch
.call2(&global, &JsValue::from_str(url), &init)?
.dyn_into()
.map_err(|_| JsValue::from_str("fetch did not return a Promise"))?;
let response = match JsFuture::from(promise).await {
Ok(r) => r,
Err(err) => {
return Err(FetchError::Js(if abort_guard.fired() {
JsValue::from_str(&format!("request timed out after {FETCH_TIMEOUT_MS} ms"))
} else {
err
}));
}
};
let status: u16 = Reflect::get(&response, &JsValue::from_str("status"))?
.as_f64()
.ok_or_else(|| JsValue::from_str("response.status missing"))? as u16;
if let Some(headers) = Reflect::get(&response, &JsValue::from_str("headers"))
.ok()
.filter(|h| !h.is_undefined() && !h.is_null())
&& let Ok(get) = get_fn(&headers, "get")
&& let Ok(declared) = get.call1(&headers, &JsValue::from_str("content-length"))
&& let Some(declared) = declared.as_string()
&& let Ok(declared) = declared.parse::<u64>()
&& declared > MAX_RESPONSE_BYTES as u64
{
return Err(FetchError::TooLarge);
}
let text_fn = get_fn(&response, "text")?;
let text_promise: Promise = text_fn
.call0(&response)?
.dyn_into()
.map_err(|_| JsValue::from_str("response.text() did not return a Promise"))?;
let body_value = JsFuture::from(text_promise).await?;
drop(abort_guard);
let js_text: js_sys::JsString = body_value
.dyn_into()
.map_err(|_| JsValue::from_str("response body not a string"))?;
if js_text.length() as usize > MAX_RESPONSE_BYTES {
return Err(FetchError::TooLarge);
}
let text = String::from(js_text);
if text.len() > MAX_RESPONSE_BYTES {
return Err(FetchError::TooLarge);
}
Ok((status, text))
}
struct AbortGuard {
global: JsValue,
timer: JsValue,
fired: std::rc::Rc<std::cell::Cell<bool>>,
_on_timeout: Closure<dyn FnMut()>,
}
impl AbortGuard {
fn install(
global: &JsValue,
init: &Object,
timeout_ms: u32,
) -> std::result::Result<Self, JsValue> {
let ctor = get_fn(global, "AbortController")?;
let controller = Reflect::construct(&ctor, &js_sys::Array::new())?;
let signal = Reflect::get(&controller, &JsValue::from_str("signal"))?;
Reflect::set(init, &JsValue::from_str("signal"), &signal)?;
let abort_fn = get_fn(&controller, "abort")?;
let fired = std::rc::Rc::new(std::cell::Cell::new(false));
let fired_clone = fired.clone();
let on_timeout = Closure::wrap(Box::new(move || {
fired_clone.set(true);
let _ = abort_fn.call0(&controller);
}) as Box<dyn FnMut()>);
let set_timeout = get_fn(global, "setTimeout")?;
let timer = set_timeout.call2(
global,
on_timeout.as_ref().unchecked_ref(),
&JsValue::from_f64(f64::from(timeout_ms)),
)?;
Ok(Self {
global: global.clone(),
timer,
fired,
_on_timeout: on_timeout,
})
}
fn fired(&self) -> bool {
self.fired.get()
}
}
impl Drop for AbortGuard {
fn drop(&mut self) {
if let Ok(clear_timeout) = get_fn(&self.global, "clearTimeout") {
let _ = clear_timeout.call1(&self.global, &self.timer);
}
}
}