use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use bonsaidb_core::{networking::Payload, schema::ApiName};
use bonsaidb_utils::fast_async_lock;
use flume::Receiver;
use url::Url;
use wasm_bindgen::{closure::Closure, JsCast, JsValue};
use web_sys::{CloseEvent, ErrorEvent, MessageEvent, WebSocket};
use crate::{
client::{AnyApiCallback, OutstandingRequestMapHandle, PendingRequest, SubscriberMap},
Error,
};
/// Starts the WebSocket worker for `url` on the wasm-bindgen local executor.
///
/// Returns immediately. The spawned task does not open a socket until the
/// first request arrives on `request_receiver`. `protocol_version` is passed
/// to the browser as the WebSocket subprotocol.
pub fn spawn_client(
    url: Arc<Url>,
    protocol_version: &'static str,
    request_receiver: Receiver<PendingRequest>,
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
    subscribers: SubscriberMap,
) {
    let connect = create_websocket(
        url,
        protocol_version,
        request_receiver,
        custom_apis,
        subscribers,
    );
    wasm_bindgen_futures::spawn_local(connect);
}
async fn create_websocket(
url: Arc<Url>,
protocol_version: &'static str,
request_receiver: Receiver<PendingRequest>,
custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
subscribers: SubscriberMap,
) {
subscribers.clear();
let initial_request = match request_receiver.recv_async().await {
Ok(r) => r,
Err(_) => return,
};
let ws = match WebSocket::new_with_str(&url.to_string(), protocol_version) {
Ok(ws) => ws,
Err(err) => {
drop(
initial_request
.responder
.send(Err(Error::from(WebSocketError::from(err)))),
);
spawn_client(
url,
protocol_version,
request_receiver,
custom_apis.clone(),
subscribers,
);
return;
}
};
let (connection_request_sender, connection_request_receiver) = flume::unbounded();
let (shutdown_sender, shutdown_receiver) = flume::unbounded();
forward_request_with_shutdown(
request_receiver.clone(),
shutdown_receiver,
connection_request_sender,
);
let initial_request = Arc::new(Mutex::new(Some(initial_request)));
ws.set_binary_type(web_sys::BinaryType::Arraybuffer);
let outstanding_requests = OutstandingRequestMapHandle::default();
let onopen_callback = on_open_callback(
connection_request_receiver,
initial_request.clone(),
outstanding_requests.clone(),
ws.clone(),
);
ws.set_onopen(Some(onopen_callback.as_ref().unchecked_ref()));
let onmessage_callback = on_message_callback(outstanding_requests, custom_apis.clone());
ws.set_onmessage(Some(onmessage_callback.as_ref().unchecked_ref()));
let onerror_callback =
on_error_callback(ws.clone(), initial_request.clone(), shutdown_sender.clone());
ws.set_onerror(Some(onerror_callback.as_ref().unchecked_ref()));
let onclose_callback = on_close_callback(
url.clone(),
protocol_version,
request_receiver.clone(),
shutdown_sender,
ws.clone(),
initial_request,
custom_apis.clone(),
subscribers.clone(),
);
ws.set_onclose(Some(onclose_callback.as_ref().unchecked_ref()));
}
/// Spawns a task that relays requests from `request_receiver` to
/// `request_sender` until it is told to stop.
///
/// The relay ends when any of these happens:
/// - `shutdown_receiver` yields a message or reports a disconnect;
/// - `request_receiver` is disconnected;
/// - `request_sender` fails to deliver.
// The expansion of `futures::select!` trips `clippy::mut_mut`.
#[allow(clippy::mut_mut)]
fn forward_request_with_shutdown(
    request_receiver: flume::Receiver<PendingRequest>,
    shutdown_receiver: flume::Receiver<()>,
    request_sender: flume::Sender<PendingRequest>,
) {
    wasm_bindgen_futures::spawn_local(async move {
        // Both futures are created once and polled again on every iteration.
        // NOTE(review): this presumably relies on flume's receive futures yielding
        // the next message when polled after completing. Re-verify on flume upgrades.
        let mut next_request = Box::pin(request_receiver.recv_async());
        let mut shutdown_signal = Box::pin(shutdown_receiver.recv_async());
        loop {
            let received = futures::select! {
                request = next_request => request,
                _ = shutdown_signal => Err(flume::RecvError::Disconnected)
            };
            match received {
                Ok(request) => {
                    if request_sender.send(request).is_err() {
                        break;
                    }
                }
                Err(_) => break,
            }
        }
    });
}
/// Builds the one-shot `onopen` handler for `ws`.
///
/// When the socket opens, the handler:
/// - sends the initial request, if no other callback has already claimed it;
/// - then sends each request received on `request_receiver`, stopping when a
///   send fails or the channel disconnects;
/// - finally closes the socket.
///
/// Successfully sent requests are recorded in `requests`.
fn on_open_callback(
    request_receiver: Receiver<PendingRequest>,
    initial_request: Arc<Mutex<Option<PendingRequest>>>,
    requests: OutstandingRequestMapHandle,
    ws: WebSocket,
) -> JsValue {
    Closure::once_into_js(move || {
        wasm_bindgen_futures::spawn_local(async move {
            if let Some(first) = take_initial_request(&initial_request) {
                let mut connected = send_request(&ws, first, &requests).await;
                while connected {
                    match request_receiver.recv_async().await {
                        Ok(pending) => connected = send_request(&ws, pending, &requests).await,
                        Err(_) => break,
                    }
                }
            }
            // Best-effort: a failure to close is deliberately ignored.
            drop(ws.close());
        });
    })
}
#[allow(clippy::future_not_send)]
async fn send_request(
ws: &WebSocket,
pending: PendingRequest,
requests: &OutstandingRequestMapHandle,
) -> bool {
let mut outstanding_requests = fast_async_lock!(requests);
let bytes = match bincode::serialize(&pending.request) {
Ok(bytes) => bytes,
Err(err) => {
drop(pending.responder.send(Err(Error::from(err))));
return true;
}
};
match ws.send_with_u8_array(&bytes) {
Ok(_) => {
outstanding_requests.insert(
pending.request.id.expect("all requests must have ids"),
pending,
);
true
}
Err(err) => {
drop(
pending
.responder
.send(Err(Error::from(WebSocketError::from(err)))),
);
false
}
}
}
fn on_message_callback(
outstanding_requests: OutstandingRequestMapHandle,
custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
) -> JsValue {
Closure::wrap(Box::new(move |e: MessageEvent| {
if let Ok(abuf) = e.data().dyn_into::<js_sys::ArrayBuffer>() {
let array = js_sys::Uint8Array::new(&abuf);
let payload = match bincode::deserialize::<Payload>(&array.to_vec()) {
Ok(payload) => payload,
Err(err) => {
log::error!("error deserializing response: {:?}", err);
return;
}
};
let outstanding_requests = outstanding_requests.clone();
let custom_apis = custom_apis.clone();
wasm_bindgen_futures::spawn_local(async move {
super::process_response_payload(payload, &outstanding_requests, &custom_apis).await;
});
} else {
log::warn!("Unexpected WebSocket message received: {:?}", e.data());
}
}) as Box<dyn FnMut(MessageEvent)>)
.into_js_value()
}
fn on_error_callback(
ws: WebSocket,
initial_request: Arc<Mutex<Option<PendingRequest>>>,
shutdown: flume::Sender<()>,
) -> JsValue {
Closure::once_into_js(move |e: ErrorEvent| {
ws.set_onerror(None);
let _ = shutdown.send(());
if let Some(initial_request) = take_initial_request(&initial_request) {
drop(
initial_request
.responder
.send(Err(Error::from(WebSocketError(
e.error().as_string().unwrap_or_default(),
)))),
);
} else {
log::error!(
"websocket error '{}'",
e.error().as_string().unwrap_or_default()
);
}
ws.close().unwrap();
})
}
/// Removes and returns the not-yet-sent initial request, leaving `None` behind.
///
/// The open, error and close callbacks all go through this function, so at
/// most one of them ever sends or fails the initial request.
///
/// # Panics
///
/// Panics if the mutex is poisoned.
fn take_initial_request(initial_request: &Mutex<Option<PendingRequest>>) -> Option<PendingRequest> {
    initial_request.lock().unwrap().take()
}
#[allow(clippy::too_many_arguments)]
fn on_close_callback(
url: Arc<Url>,
protocol_version: &'static str,
request_receiver: Receiver<PendingRequest>,
shutdown: flume::Sender<()>,
ws: WebSocket,
initial_request: Arc<Mutex<Option<PendingRequest>>>,
custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
subscribers: SubscriberMap,
) -> JsValue {
Closure::once_into_js(move |c: CloseEvent| {
let _ = shutdown.send(());
ws.set_onclose(None);
if let Some(initial_request) = take_initial_request(&initial_request) {
drop(
initial_request
.responder
.send(Err(Error::from(WebSocketError(format!(
"connection closed ({}). Reason: {:?}",
c.code(),
c.reason()
))))),
);
} else {
log::error!("websocket closed ({}): {:?}", c.code(), c.reason());
}
spawn_client(
url,
protocol_version,
request_receiver,
custom_apis.clone(),
subscribers,
);
})
}
#[derive(thiserror::Error, Debug)]
#[error("WebSocket error: {0}")]
pub struct WebSocketError(String);
impl From<JsValue> for WebSocketError {
fn from(value: JsValue) -> Self {
Self(if let Some(value) = value.as_string() {
value
} else if let Some(value) = value.as_f64() {
value.to_string()
} else if let Some(value) = value.as_bool() {
value.to_string()
} else if value.is_null() {
String::from("(null)")
} else {
String::from("(undefined)")
})
}
}