use async_trait::async_trait;
use bytes::Bytes;
use std::io::{self, Read, Write};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use wasm_bindgen::prelude::*;
use crate::hub::event::{sleep, GenericSource, IOInterest, IONotifier};
use crate::hub::transport::pool::{AddressedStreamFactory, NonBlockingStream};
use crate::hub::transport::{IOSource, TransportError};
use crate::hub::utils::{error, info, to_js_error, warn};
/// Browser `web_sys::WebSocket` held inside a `send_wrapper::SendWrapper`.
///
/// NOTE(review): presumably the wrapper exists so the socket can live in types
/// that are required to be `Send` — confirm against the users of this type.
struct SendableWS(send_wrapper::SendWrapper<web_sys::WebSocket>);

impl SendableWS {
    /// Wraps an already constructed browser socket.
    fn new(ws: web_sys::WebSocket) -> Self {
        let wrapped = send_wrapper::SendWrapper::new(ws);
        Self(wrapped)
    }

    /// Installs (`Some`) or removes (`None`) the `onopen` handler.
    ///
    /// The closure handed to JavaScript is leaked with `forget` so that it
    /// stays callable after this method returns.
    fn set_onopen(&mut self, callback: Option<Box<dyn Fn() + Send>>) {
        if let Some(handler) = callback {
            let js_handler = Closure::<dyn Fn()>::wrap(handler);
            self.0.set_onopen(Some(js_handler.as_ref().unchecked_ref()));
            js_handler.forget();
        } else {
            self.0.set_onopen(None);
        }
    }

    /// Installs (`Some`) or removes (`None`) the `onmessage` handler.
    ///
    /// The closure handed to JavaScript is leaked with `forget` so that it
    /// stays callable after this method returns.
    fn set_onmessage(&mut self, callback: Option<Box<dyn Fn(web_sys::MessageEvent) + Send>>) {
        if let Some(handler) = callback {
            let js_handler = Closure::<dyn Fn(_)>::wrap(handler);
            self.0.set_onmessage(Some(js_handler.as_ref().unchecked_ref()));
            js_handler.forget();
        } else {
            self.0.set_onmessage(None);
        }
    }

    /// Installs (`Some`) or removes (`None`) the `onerror` handler.
    ///
    /// The closure handed to JavaScript is leaked with `forget` so that it
    /// stays callable after this method returns.
    fn set_onerror(&mut self, callback: Option<Box<dyn Fn(web_sys::MessageEvent) + Send>>) {
        if let Some(handler) = callback {
            let js_handler = Closure::<dyn Fn(_)>::wrap(handler);
            self.0.set_onerror(Some(js_handler.as_ref().unchecked_ref()));
            js_handler.forget();
        } else {
            self.0.set_onerror(None);
        }
    }

    /// Returns the socket's `bufferedAmount` as reported by the browser.
    fn buffered_amount(&mut self) -> u32 {
        let socket = &self.0;
        socket.buffered_amount()
    }

    /// Returns the socket's `readyState` as reported by the browser.
    fn ready_state(&mut self) -> u16 {
        let socket = &self.0;
        socket.ready_state()
    }

    /// Copies `data` into a JavaScript `ArrayBuffer` and sends it as one
    /// binary message. Any JavaScript exception is collapsed into `Err(())`.
    fn send(&mut self, data: Bytes) -> Result<(), ()> {
        let payload = js_sys::Uint8Array::from(data.as_ref()).buffer();
        match self.0.send_with_array_buffer(&payload) {
            Ok(()) => Ok(()),
            Err(_) => Err(()),
        }
    }

    /// Starts the closing handshake. Any JavaScript exception is collapsed
    /// into `Err(())`.
    fn close(&mut self) -> Result<(), ()> {
        match self.0.close() {
            Ok(()) => Ok(()),
            Err(_) => Err(()),
        }
    }
}
pub struct WebSocket {
source: WSSource,
poll_sleep: bool,
poll_interval: Duration,
received: mpsc::UnboundedReceiver<Bytes>,
}
pub struct WSSource {
ws: Option<SendableWS>,
received_tx: mpsc::UnboundedSender<Bytes>,
notifier: Option<IONotifier>,
}
impl GenericSource for WSSource {
fn register(&mut self, notifier: IONotifier) -> Result<(), io::Error> {
if let Some(ws) = self.ws.as_mut() {
let received_tx = self.received_tx.clone();
let notifier_clone = notifier.clone();
let callback = Box::new(move |e: web_sys::MessageEvent| {
let data: Bytes = if let Ok(buff) = e.data().dyn_into::<js_sys::ArrayBuffer>() {
js_sys::Uint8Array::new(&buff).to_vec().into()
} else if let Ok(msg) = e.data().dyn_into::<js_sys::JsString>() {
match msg.as_string() {
Some(msg) => msg.as_bytes().to_vec().into(),
None => {
warn!("[WS] cannot convert to String, dropping");
return
}
}
} else {
warn!("[WS] received Blob data, expect Binary or String, dropping");
return
};
if !data.is_empty() {
received_tx.send(data).ok();
notifier_clone.blocking_notify(IOInterest::READABLE);
}
});
ws.set_onmessage(Some(callback));
notifier.blocking_notify(IOInterest::READABLE | IOInterest::WRITABLE);
self.notifier = Some(notifier);
}
Ok(())
}
fn deregister(&mut self) -> Result<(), io::Error> {
if let Some(ws) = self.ws.as_mut() {
ws.set_onmessage(None);
self.notifier = None;
}
Ok(())
}
}
impl WebSocket {
const MAX_BUFFERED_AMOUNT: u32 = 65536;
const INIT_POLL_INTERVAL: Duration = Duration::from_millis(1);
const MAX_POLL_INTERVAL: Duration = Duration::from_millis(1000);
const POLL_INC_FACTOR: u32 = 2;
pub async fn new(url: &str) -> Result<Self, JsValue> {
let (received_tx, received) = mpsc::unbounded_channel();
let mut ret = Self {
source: WSSource {
notifier: None,
ws: None,
received_tx,
},
received,
poll_sleep: false,
poll_interval: Self::INIT_POLL_INTERVAL,
};
match web_sys::WebSocket::new(url) {
Ok(ws) => {
ws.set_binary_type(web_sys::BinaryType::Arraybuffer);
ret.source.ws = Some(SendableWS::new(ws))
}
Err(_) => return Err(to_js_error(TransportError::BothTerminated)),
}
let mut ws = ret.source.ws.take().unwrap();
let state = ws.ready_state();
match state {
web_sys::WebSocket::CONNECTING => {
let connected = Arc::new(tokio::sync::Notify::new());
let c = connected.clone();
let erorred = Arc::new(tokio::sync::Notify::new());
let e = erorred.clone();
ws.set_onopen(Some(Box::new(move || c.notify_one())));
ws.set_onerror(Some(Box::new(move |_| e.notify_one())));
tokio::select! {
_ = connected.notified() => {
info!("[WS] connected");
ret.source.ws = Some(ws);
},
_ = erorred.notified() => {
error!("[WS] error while connecting");
ret.source.ws = Some(ws);
return Err(to_js_error(TransportError::BothTerminated));
},
}
}
web_sys::WebSocket::OPEN => {}
_ => return Err(to_js_error(TransportError::BothTerminated)),
}
Ok(ret)
}
fn schedule_poll(&self) {
if let Some(notifier) = &self.source.notifier {
let notifier = notifier.clone();
let interval = self.poll_interval;
wasm_bindgen_futures::spawn_local(async move {
sleep(interval).await;
notifier.notify(IOInterest::WRITABLE).await;
});
}
}
}
impl Read for WebSocket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.try_recv() {
Ok(data) => {
let len = data.len();
buf[..len].copy_from_slice(&data);
Ok(len)
}
_ => Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")),
}
}
}
impl Write for WebSocket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.try_send(Some(buf.to_vec().into())) {
Ok(true) => Ok(buf.len()),
Ok(false) => Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")),
Err(e) => Err(io::Error::new(io::ErrorKind::BrokenPipe, format!("{:?}", e))),
}
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl NonBlockingStream for WebSocket {
fn try_recv(&mut self) -> Result<Bytes, TransportError> {
self.received.try_recv().map_err(|_| TransportError::NotReady)
}
fn try_send(&mut self, data: Option<Bytes>) -> Result<bool, TransportError> {
match data {
Some(data) => {
if self.source.ws.as_mut().unwrap().buffered_amount() + (data.len() as u32) > Self::MAX_BUFFERED_AMOUNT
{
if self.poll_sleep {
if self.poll_interval < Self::MAX_POLL_INTERVAL {
self.poll_interval *= Self::POLL_INC_FACTOR
}
} else {
self.poll_sleep = true;
self.poll_interval = Self::INIT_POLL_INTERVAL;
}
self.schedule_poll();
return Err(TransportError::NotReady)
}
self.poll_sleep = false;
match self.source.ws.as_mut().unwrap().send(data) {
Ok(_) => Ok(true),
Err(_) => Err(TransportError::BothTerminated),
}
}
None => Ok(false),
}
}
fn shutdown(&mut self, _how: std::net::Shutdown) -> io::Result<()> {
if let Some(mut ws) = self.source.ws.take() {
match ws.close() {
Ok(_) => Ok(()),
Err(_) => Err(io::Error::new(io::ErrorKind::BrokenPipe, "ws is broken")),
}
} else {
Err(io::Error::new(io::ErrorKind::BrokenPipe, "ws is not initialized"))
}
}
fn source(&mut self) -> IOSource {
IOSource::Generic(&mut self.source)
}
}
/// Stateless factory for WebSocket-backed streams.
#[derive(Clone)]
pub struct Factory;

impl Factory {
    /// Creates a factory. The bind address is accepted but not used.
    pub fn new(_bind_addr: Option<std::net::SocketAddr>) -> Self {
        Factory
    }
}
#[async_trait]
impl AddressedStreamFactory for Factory {
    /// Opens a WebSocket to `addr` and returns it as a boxed stream, or
    /// `None` when the connection could not be established.
    async fn create_stream(&self, addr: &str) -> Option<Box<dyn NonBlockingStream>> {
        match WebSocket::new(addr).await {
            Ok(stream) => Some(Box::new(stream)),
            Err(_) => None,
        }
    }
}