// websocket-client-async 0.1.4
//
// Asynchronous websocket client.
use std::cmp::Ordering;
use std::io::ErrorKind;
use std::pin::Pin;
use std::task::Poll;

use anyhow::{anyhow, bail, Context, Result};
use futures_channel::mpsc::Receiver;
use futures_core::stream::Stream;
use futures_io::AsyncBufRead;
use futures_io::AsyncRead;
use futures_io::AsyncWrite;
use futures_util::StreamExt;
use js_sys::Uint8Array;
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsCast;
use web_sys::{ErrorEvent, MessageEvent, WebSocket};

/// An opened browser WebSocket paired with the reader that receives its messages.
///
/// Built by `WebsocketIO::new` / `WebsocketIO::new_wss`; `WebsocketIO::split` separates it
/// into a read half and a write half.
pub(crate) struct WebsocketIO {
    /// The `web_sys` socket handle, used for sending and closing.
    pub(crate) ws: WebSocket,
    /// Byte-stream view of the messages received on `ws`.
    pub(crate) reader: WebSocketReader,
    /// Full URL the socket was opened with, scheme (`ws://` / `wss://`) included.
    pub(crate) ws_url: String,
}

/// Read half of a websocket connection: received binary messages exposed as a byte stream.
///
/// Implements `AsyncRead` and `AsyncBufRead`; message boundaries are not preserved.
pub struct WebSocketReader {
    /// Messages forwarded from the socket's `onmessage` callback.
    read_rx: Receiver<Uint8Array>,
    /// Bytes already taken off `read_rx` but not yet handed to the caller.
    remaining: Vec<u8>,
}
/// Write half of a websocket connection; each `poll_write` sends one binary message.
pub(crate) struct WebSocketInner {
    /// The `web_sys` socket handle shared with the read side's callbacks.
    pub(crate) ws: WebSocket,
    /// Full URL the socket was opened with.
    pub(crate) ws_url: String,
}

// SAFETY: NOTE(review): `WebSocket` is a handle to a JS object and is not `Send`/`Sync` on
// its own (otherwise these manual impls would be unnecessary). They are only sound if a
// `WebSocketInner` never actually crosses threads, e.g. on a single-threaded wasm32 target.
// They presumably exist to satisfy `Send`/`Sync` bounds required by callers; re-check this
// before enabling wasm threads/atomics.
unsafe impl Send for WebSocketInner {}
unsafe impl Sync for WebSocketInner {}

impl WebsocketIO {
    #[inline]
    pub async fn new(addr: &str) -> Result<WebsocketIO> {
        WebsocketIO::new_inner(format!("ws://{}", addr)).await
    }
    #[inline]
    pub async fn new_wss(addr: &str) -> Result<WebsocketIO> {
        WebsocketIO::new_inner(format!("wss://{}", addr)).await
    }

    #[inline]
    async fn new_inner(url: String) -> Result<WebsocketIO> {
        let ws = {
            match WebSocket::new(&url) {
                Ok(ws) => ws,
                Err(err) => bail!("WebSocket new error:{:?}", err),
            }
        };

        ws.set_binary_type(web_sys::BinaryType::Arraybuffer);

        let buffer = 4;
        let (mut open_tx, mut open_rx) = futures_channel::mpsc::channel(1);
        let (read_tx, read_rx) = futures_channel::mpsc::channel(buffer);

        let onmessage_callback = Closure::wrap(Box::new(move |e: MessageEvent| {
            let mut read_tx = read_tx.clone();
            if let Ok(abuf) = e.data().dyn_into::<js_sys::ArrayBuffer>() {
                let array = js_sys::Uint8Array::new(&abuf);
                if !read_tx.is_closed() {
                    if let Err(err) = read_tx.start_send(array) {
                        log::error!("start_send error: {:?}", err);
                    }
                }
            } else if let Ok(blob) = e.data().dyn_into::<web_sys::Blob>() {
                let fr = web_sys::FileReader::new().expect("web_sys::FileReader::new() fail");
                let fr_c = fr.clone();
                let file_reader_load_end =
                    Closure::wrap(Box::new(move |_e: web_sys::ProgressEvent| {
                        let array = Uint8Array::new(
                            &fr_c.result().expect("web_sys::FileReader result() err"),
                        );
                        if !read_tx.is_closed() {
                            if let Err(err) = read_tx.start_send(array) {
                                log::error!("start_send error: {:?}", err);
                            }
                        }
                    }) as Box<dyn FnMut(web_sys::ProgressEvent)>);
                fr.set_onloadend(Some(file_reader_load_end.as_ref().unchecked_ref()));
                file_reader_load_end.forget();
                fr.read_as_array_buffer(&blob).expect("blob not readable");
            }
            else {
                log::error!("message event, received Unknown: {:?}", e.data());
            }
        }) as Box<dyn Fn(MessageEvent)>);

        let mut error_tx = open_tx.clone();
        let onerror_callback = Closure::wrap(Box::new(move |e: ErrorEvent| {
            log::error!("error event: {:?}", e);
            if !error_tx.is_closed() {
                let _= error_tx.start_send(0);
                error_tx.close_channel();
            }
        }) as Box<dyn FnMut(ErrorEvent)>);

        let onopen_callback =
            Closure::wrap(Box::new(move |_| {
                if !open_tx.is_closed() {
                    let _= open_tx.start_send(1);
                    open_tx.close_channel();
                }
            })
            as Box<dyn FnMut(JsValue)>);

        ws.set_onmessage(Some(onmessage_callback.as_ref().unchecked_ref()));
        onmessage_callback.forget();

        ws.set_onerror(Some(onerror_callback.as_ref().unchecked_ref()));
        onerror_callback.forget();

        ws.set_onopen(Some(onopen_callback.as_ref().unchecked_ref()));
        onopen_callback.forget();

        let reader = WebSocketReader {
            read_rx,
            remaining: Vec::new(),
        };

        if  open_rx.next().await.context("open_rx is none")? ==1 {
            open_rx.close();
            drop(open_rx);
            let ws_io = WebsocketIO {
                ws,
                reader,
                ws_url: url,
            };
            Ok(ws_io)
        }else{
            open_rx.close();
            drop(open_rx);
            bail!("connect to:{} fail",url)
        }
    }

    #[inline]
    pub fn split(self) -> (WebSocketReader, WebSocketInner) {
        let WebsocketIO { ws, reader, ws_url } = self;
        (reader, WebSocketInner { ws, ws_url })
    }
}

impl WebSocketReader {
    #[inline]
    fn write_remaining(&mut self, buf: &mut [u8]) -> usize {
        match self.remaining.len().cmp(&buf.len()) {
            Ordering::Less => {
                let amount = self.remaining.len();
                buf[0..amount].copy_from_slice(&self.remaining);
                self.remaining.clear();
                amount
            }
            Ordering::Equal => {
                buf.copy_from_slice(&self.remaining);
                self.remaining.clear();
                buf.len()
            }
            Ordering::Greater => {
                let amount = buf.len();
                buf.copy_from_slice(&self.remaining[..amount]);
                self.remaining.drain(0..amount);
                amount
            }
        }
    }
}

impl AsyncRead for WebSocketReader {
    /// Copies received message bytes into `buf`.
    ///
    /// Leftover bytes from a message larger than `buf` are buffered and returned by later
    /// calls before the next message is taken. Returns `Ok(0)` for an empty `buf`, or once the
    /// read channel has been closed and fully drained (end-of-stream).
    #[inline]
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> Poll<std::io::Result<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        if !self.remaining.is_empty() {
            return Poll::Ready(Ok(self.write_remaining(buf)));
        }

        loop {
            let array = match Pin::new(&mut self.read_rx).poll_next(cx) {
                Poll::Ready(Some(item)) => item,
                // Closed and drained: report EOF. Returning `Pending` here would never be
                // woken, because no waker is registered on a finished stream.
                Poll::Ready(None) => return Poll::Ready(Ok(0)),
                Poll::Pending => return Poll::Pending,
            };

            let array_length = array.length() as usize;
            if array_length == 0 {
                // An empty message carries no bytes; `Ok(0)` would be misread as EOF.
                continue;
            }
            if array_length <= buf.len() {
                array.copy_to(&mut buf[..array_length]);
                return Poll::Ready(Ok(array_length));
            }

            // The message is larger than `buf`: stash it and hand out the first part now.
            self.remaining.resize(array_length, 0);
            array.copy_to(self.remaining.as_mut_slice());
            return Poll::Ready(Ok(self.write_remaining(buf)));
        }
    }
}
impl AsyncBufRead for WebSocketReader {
    /// Returns the unread bytes of the current message, taking the next message if needed.
    ///
    /// An empty slice means end-of-stream: the read channel was closed and fully drained.
    #[inline]
    fn poll_fill_buf(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<futures_io::Result<&[u8]>> {
        let this = self.get_mut();
        // Keep polling past empty messages. `poll_next` just returned `Ready`, so no waker is
        // registered, and a bare `Pending` here would never be woken.
        while this.remaining.is_empty() {
            match Pin::new(&mut this.read_rx).poll_next(cx) {
                Poll::Ready(Some(array)) => {
                    // Copy straight into the (empty) buffer, avoiding a temporary `Vec`.
                    this.remaining.resize(array.length() as usize, 0);
                    array.copy_to(this.remaining.as_mut_slice());
                }
                Poll::Ready(None) => break,
                Poll::Pending => return Poll::Pending,
            }
        }
        Poll::Ready(Ok(this.remaining.as_slice()))
    }

    /// Marks `amt` bytes of the slice returned by `poll_fill_buf` as read.
    #[inline]
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        let this = self.get_mut();
        // Clamp so an out-of-contract `amt` cannot panic inside `drain`.
        let amt = amt.min(this.remaining.len());
        this.remaining.drain(..amt);
    }
}

impl AsyncWrite for WebSocketInner {
    #[inline]
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        _: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match self.ws.send_with_u8_array(buf){
            Ok(_)=>{
                Poll::Ready(Ok(buf.len()))
            },
            Err(err)=>{
                Poll::Ready(Err(std::io::Error::new(ErrorKind::Other,anyhow!("{:?}",err))))
            }
        }
    }
    #[inline]
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _: &mut std::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
    #[inline]
    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        _: &mut std::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        match self.ws.close(){
            Ok(_)=>{
                Poll::Ready(Ok(()))
            },
            Err(err)=>{
                Poll::Ready(Err(std::io::Error::new(ErrorKind::Other,anyhow!("{:?}",err))))
            }
        }
    }
}