websocket-web 0.1.9

WebSockets on the web 🕸️ — WebSocket support in a JavaScript runtime environment, usually a web browser.
Documentation
//! WebSocket streams API.

use futures_core::Stream;
use futures_sink::Sink;
use futures_util::FutureExt;
use js_sys::{Array, Object, Promise, Reflect, Uint8Array};
use std::{
    cell::Cell,
    io,
    io::ErrorKind,
    ops::Deref,
    pin::Pin,
    rc::Rc,
    task::{ready, Context, Poll},
};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use web_sys::{
    ReadableStream, ReadableStreamDefaultReader, ReadableStreamReadResult, WritableStream,
    WritableStreamDefaultWriter,
};

use crate::{
    closed::{CloseCode, Closed, ClosedReason},
    util::{js_err, js_err_msg},
    Info, Interface, Msg, WebSocketBuilder,
};

const DEFAULT_SEND_BUFFER_SIZE: usize = 4_194_304;

#[wasm_bindgen]
extern "C" {
    type WebSocketStream;

    #[wasm_bindgen(constructor, catch)]
    fn new(url: &str, options: &JsValue) -> Result<WebSocketStream, JsValue>;

    #[wasm_bindgen(method, getter, catch)]
    async fn opened(this: &WebSocketStream) -> Result<WebSocketStreamOpened, JsValue>;

    #[wasm_bindgen(method, getter)]
    fn closed(this: &WebSocketStream) -> Promise;

    #[wasm_bindgen(method, catch)]
    fn close(this: &WebSocketStream, options: &JsValue) -> Result<(), JsValue>;

    #[wasm_bindgen(method, getter)]
    fn url(this: &WebSocketStream) -> String;
}

#[wasm_bindgen]
extern "C" {
    type WebSocketStreamOpened;

    #[wasm_bindgen(method, getter)]
    fn extensions(this: &WebSocketStreamOpened) -> String;

    #[wasm_bindgen(method, getter)]
    fn protocol(this: &WebSocketStreamOpened) -> String;

    #[wasm_bindgen(method, getter)]
    fn readable(this: &WebSocketStreamOpened) -> ReadableStream;

    #[wasm_bindgen(method, getter)]
    fn writable(this: &WebSocketStreamOpened) -> WritableStream;
}

#[wasm_bindgen]
extern "C" {
    type WebSocketStreamClosed;

    #[wasm_bindgen(method, getter)]
    fn closeCode(this: &WebSocketStreamClosed) -> u32;

    #[wasm_bindgen(method, getter)]
    fn reason(this: &WebSocketStreamClosed) -> String;
}

struct Guard {
    socket: WebSocketStream,
    closed: Cell<bool>,
}

impl Guard {
    fn new(socket: WebSocketStream) -> Self {
        Self { socket, closed: Cell::new(false) }
    }
}

impl Deref for Guard {
    type Target = WebSocketStream;
    fn deref(&self) -> &Self::Target {
        &self.socket
    }
}

impl Drop for Guard {
    fn drop(&mut self) {
        if !self.closed.get() {
            self.socket.close(&JsValue::null()).unwrap();
        }
    }
}

pub struct Inner {
    socket: Rc<Guard>,
    pub(crate) sender: Sender,
    pub(crate) receiver: Receiver,
}

impl Inner {
    pub async fn new(builder: WebSocketBuilder) -> io::Result<(Self, Info)> {
        // Create WebSocketStream.
        let options = Object::new();
        if !builder.protocols.is_empty() {
            let arr = Array::new();
            for proto in builder.protocols {
                arr.push(&JsValue::from_str(&proto));
            }
            Reflect::set(&options, &JsValue::from_str("protocols"), &arr).unwrap();
        }
        let socket = WebSocketStream::new(&builder.url, &JsValue::from(options))
            .map_err(|js| js_err(ErrorKind::ConnectionRefused, &js))?;
        let socket = Rc::new(Guard::new(socket));

        // Open WebSocket connection.
        let opened = match socket.opened().await {
            Ok(opened) => opened,
            Err(js) => return Err(js_err(ErrorKind::ConnectionRefused, &js)),
        };

        // Obtain reader and writer.
        let writer = opened.writable().get_writer().unwrap();
        let reader = opened.readable().get_reader().dyn_into::<ReadableStreamDefaultReader>().unwrap();

        Ok((
            Self {
                socket: socket.clone(),
                sender: Sender::new(socket.clone(), writer, builder.send_buffer_size),
                receiver: Receiver::new(socket.clone(), reader),
            },
            Info { url: socket.url(), protocol: opened.protocol(), interface: Interface::Stream },
        ))
    }

    pub fn closed(&self) -> Closed {
        let closed = self.socket.closed();
        Closed(
            async move {
                match JsFuture::from(closed).await {
                    Ok(c) => {
                        let c: WebSocketStreamClosed = c.unchecked_into();
                        ClosedReason {
                            code: CloseCode::from(c.closeCode() as u16),
                            reason: c.reason(),
                            was_clean: true,
                        }
                    }
                    Err(err) => ClosedReason {
                        code: CloseCode::AbnormalClosure,
                        reason: js_err_msg(&err).unwrap_or_default(),
                        was_clean: false,
                    },
                }
            }
            .boxed_local(),
        )
    }

    pub fn into_split(self) -> (Sender, Receiver) {
        (self.sender, self.receiver)
    }
}

pub struct Sender {
    socket: Rc<Guard>,
    writer: WritableStreamDefaultWriter,
    writing: Option<JsFuture>,
    closing: Option<JsFuture>,
    buffered: usize,
    send_buffer_size: usize,
}

impl Sender {
    fn new(socket: Rc<Guard>, writer: WritableStreamDefaultWriter, send_buffer_size: Option<usize>) -> Self {
        Self {
            socket,
            writer,
            writing: None,
            closing: None,
            buffered: 0,
            send_buffer_size: send_buffer_size.unwrap_or(DEFAULT_SEND_BUFFER_SIZE),
        }
    }

    #[track_caller]
    pub fn close(self, code: u16, reason: &str) {
        let options = Object::new();
        Reflect::set(&options, &JsValue::from("closeCode"), &JsValue::from(code)).unwrap();
        Reflect::set(&options, &JsValue::from("reason"), &JsValue::from_str(reason)).unwrap();
        self.socket.close(&options).unwrap();
        self.socket.closed.set(true);
    }
}

impl Sink<(&JsValue, usize)> for Sender {
    type Error = io::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let Some(writing) = &mut self.writing else {
            return Poll::Ready(Ok(()));
        };

        let res = match ready!(writing.poll_unpin(cx)) {
            Ok(_) => Ok(()),
            Err(err) => Err(js_err(ErrorKind::ConnectionReset, &err)),
        };

        self.writing = None;
        self.buffered = 0;

        Poll::Ready(res)
    }

    fn start_send(mut self: Pin<&mut Self>, (item, len): (&JsValue, usize)) -> Result<(), Self::Error> {
        if self.writing.is_some() {
            panic!("WebSocket not ready for sending");
        }

        let promise = self.writer.write_with_chunk(item);
        self.buffered += len;

        if self.buffered >= self.send_buffer_size {
            self.writing = Some(JsFuture::from(promise));
        }

        Ok(())
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        if self.closing.is_none() {
            self.closing = Some(JsFuture::from(self.writer.close()));
        }

        let Some(closing) = &mut self.closing else { unreachable!() };
        let res = match ready!(closing.poll_unpin(cx)) {
            Ok(_) => Ok(()),
            Err(err) => Err(js_err(ErrorKind::ConnectionReset, &err)),
        };

        self.closing = None;
        Poll::Ready(res)
    }
}

impl Drop for Sender {
    fn drop(&mut self) {
        // empty
    }
}

pub struct Receiver {
    _socket: Rc<Guard>,
    reader: ReadableStreamDefaultReader,
    reading: Option<JsFuture>,
}

impl Receiver {
    fn new(socket: Rc<Guard>, reader: ReadableStreamDefaultReader) -> Self {
        Self { _socket: socket, reader, reading: None }
    }
}

impl Stream for Receiver {
    type Item = io::Result<Msg>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        if self.reading.is_none() {
            self.reading = Some(JsFuture::from(self.reader.read()));
        }

        let Some(reading) = &mut self.reading else { unreachable!() };
        let res = match ready!(reading.poll_unpin(cx)) {
            Ok(data) => {
                let result: ReadableStreamReadResult = data.unchecked_into();
                if result.get_done().unwrap_or_default() {
                    self.reading = None;
                    None
                } else {
                    let chunk = result.get_value();
                    // Pipeline: start next read immediately to reduce per-message latency.
                    self.reading = Some(JsFuture::from(self.reader.read()));
                    if chunk.is_string() {
                        Some(Ok(Msg::Text(chunk.as_string().unwrap())))
                    } else {
                        Some(Ok(Msg::Binary(Uint8Array::new(&chunk).to_vec())))
                    }
                }
            }
            Err(err) => {
                self.reading = None;
                Some(Err(js_err(ErrorKind::ConnectionReset, &err)))
            }
        };

        Poll::Ready(res)
    }
}

impl Drop for Receiver {
    fn drop(&mut self) {
        // empty
    }
}