#![warn(missing_docs)]
#[cfg(not(target_family = "wasm"))]
compile_error!("websocket-web requires a WebAssembly target");
mod closed;
mod standard;
mod stream;
mod util;
use futures_core::Stream;
use futures_sink::Sink;
use futures_util::{SinkExt, StreamExt};
use js_sys::{Reflect, Uint8Array};
use std::{
fmt, io,
io::ErrorKind,
mem,
pin::Pin,
rc::Rc,
task::{ready, Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use wasm_bindgen::prelude::*;
pub use closed::{CloseCode, Closed, ClosedReason};
/// Browser API used to implement a [`WebSocket`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Interface {
    /// The `WebSocketStream` API.
    Stream,
    /// The classic `WebSocket` API.
    Standard,
}

impl Interface {
    /// Returns whether the JavaScript global object has a property named after this
    /// interface's constructor (`WebSocketStream` or `WebSocket`).
    ///
    /// A failing property lookup is reported as "not supported".
    pub fn is_supported(&self) -> bool {
        let constructor = match self {
            Self::Stream => "WebSocketStream",
            Self::Standard => "WebSocket",
        };
        Reflect::has(&js_sys::global(), &JsValue::from_str(constructor)).unwrap_or_default()
    }
}
/// A single WebSocket message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Msg {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
}

impl Msg {
    /// Returns `true` if this is a [`Msg::Text`] message.
    pub const fn is_text(&self) -> bool {
        matches!(self, Self::Text(_))
    }

    /// Returns `true` if this is a [`Msg::Binary`] message.
    pub const fn is_binary(&self) -> bool {
        matches!(self, Self::Binary(_))
    }

    /// Consumes the message and returns its payload as bytes.
    ///
    /// Text messages yield their UTF-8 encoding. The message's existing allocation
    /// is reused, so no copy is made.
    pub fn to_vec(self) -> Vec<u8> {
        match self {
            // `into_bytes` hands over the String's buffer instead of copying it.
            Self::Text(text) => text.into_bytes(),
            Self::Binary(vec) => vec,
        }
    }

    /// Length of the payload in bytes. For text this is the UTF-8 byte length,
    /// not the number of characters.
    pub fn len(&self) -> usize {
        match self {
            Self::Text(text) => text.len(),
            Self::Binary(vec) => vec.len(),
        }
    }

    /// Returns `true` if the payload contains no bytes.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Text(text) => text.is_empty(),
            Self::Binary(vec) => vec.is_empty(),
        }
    }
}

/// Text is shown as-is. Binary payloads are decoded as UTF-8, with invalid
/// sequences replaced by `U+FFFD`.
impl fmt::Display for Msg {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Text(text) => f.write_str(text),
            Self::Binary(binary) => f.write_str(&String::from_utf8_lossy(binary)),
        }
    }
}

/// Converts a message into its payload bytes; see [`Msg::to_vec`].
impl From<Msg> for Vec<u8> {
    fn from(msg: Msg) -> Self {
        msg.to_vec()
    }
}

/// Borrows the payload bytes (UTF-8 bytes for text messages).
impl AsRef<[u8]> for Msg {
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Text(text) => text.as_bytes(),
            Self::Binary(vec) => vec,
        }
    }
}
/// Builder for configuring and opening a [`WebSocket`].
#[derive(Debug, Clone)]
pub struct WebSocketBuilder {
    /// URL to connect to.
    url: String,
    /// Subprotocols handed to the interface implementation when connecting.
    protocols: Vec<String>,
    /// Explicitly requested interface; `None` selects one automatically.
    interface: Option<Interface>,
    /// Optional send buffer size for the interface implementation.
    send_buffer_size: Option<usize>,
    /// Optional receive buffer size for the interface implementation.
    receive_buffer_size: Option<usize>,
}

impl WebSocketBuilder {
    /// Creates a builder for `url`.
    ///
    /// The builder starts with no subprotocols, automatic interface selection and
    /// no explicit buffer sizes.
    pub fn new(url: impl AsRef<str>) -> Self {
        let url = url.as_ref().to_owned();
        Self { url, protocols: vec![], interface: None, send_buffer_size: None, receive_buffer_size: None }
    }

    /// Forces the given browser interface instead of choosing one automatically.
    pub fn set_interface(&mut self, interface: Interface) {
        self.interface = Some(interface);
    }

    /// Sets the subprotocol list passed to the interface when connecting.
    ///
    /// This replaces any previously set list.
    pub fn set_protocols<P>(&mut self, protocols: impl IntoIterator<Item = P>)
    where
        P: AsRef<str>,
    {
        self.protocols = protocols.into_iter().map(|p| p.as_ref().to_owned()).collect();
    }

    /// Sets the send buffer size handed to the interface implementation.
    ///
    /// NOTE(review): the unit (bytes vs. messages) is defined by the interface modules.
    pub fn set_send_buffer_size(&mut self, send_buffer_size: usize) {
        self.send_buffer_size = Some(send_buffer_size);
    }

    /// Sets the receive buffer size handed to the interface implementation.
    ///
    /// NOTE(review): the unit (bytes vs. messages) is defined by the interface modules.
    pub fn set_receive_buffer_size(&mut self, receive_buffer_size: usize) {
        self.receive_buffer_size = Some(receive_buffer_size);
    }

    /// Opens the connection.
    ///
    /// If no interface was set, `WebSocketStream` is used when the browser exposes it.
    /// Otherwise the classic `WebSocket` API is used.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::Unsupported`] error if the chosen interface is not
    /// available. Any error raised while establishing the connection is propagated.
    pub async fn connect(self) -> io::Result<WebSocket> {
        let interface = self.interface.unwrap_or_else(|| {
            if Interface::Stream.is_supported() {
                Interface::Stream
            } else {
                Interface::Standard
            }
        });

        if !interface.is_supported() {
            let message = match interface {
                Interface::Stream => "WebSocketStream not supported",
                Interface::Standard => "WebSocket not supported",
            };
            return Err(io::Error::new(ErrorKind::Unsupported, message));
        }

        let (inner, info) = match interface {
            Interface::Stream => {
                let (stream, info) = stream::Inner::new(self).await?;
                (Inner::Stream(stream), info)
            }
            Interface::Standard => {
                let (standard, info) = standard::Inner::new(self).await?;
                (Inner::Standard(standard), info)
            }
        };
        Ok(WebSocket { inner, info: Rc::new(info), read_buf: Vec::new() })
    }
}
/// Connection metadata shared by a [`WebSocket`] and the halves produced by splitting it.
struct Info {
    /// URL the connection was opened with.
    url: String,
    /// Subprotocol recorded for the connection.
    protocol: String,
    /// Browser interface backing the connection.
    interface: Interface,
}

/// A WebSocket connection backed by either the `WebSocketStream` or the classic
/// `WebSocket` browser API.
///
/// Messages are exchanged through the `Sink`/`Stream` implementations. The
/// `AsyncRead`/`AsyncWrite` implementations let the connection be used as a byte
/// stream instead. Use [`WebSocket::into_split`] to get independently owned halves.
pub struct WebSocket {
    /// Interface-specific connection state.
    inner: Inner,
    /// Metadata shared with split halves.
    info: Rc<Info>,
    /// Bytes of the last message consumed by `AsyncRead` that have not been returned yet.
    read_buf: Vec<u8>,
}

/// Interface-specific implementation behind a [`WebSocket`].
enum Inner {
    /// Backed by `WebSocketStream`.
    Stream(stream::Inner),
    /// Backed by the classic `WebSocket` API.
    Standard(standard::Inner),
}

impl fmt::Debug for WebSocket {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Info { url, protocol, interface } = &*self.info;
        f.debug_struct("WebSocket")
            .field("url", url)
            .field("protocol", protocol)
            .field("interface", interface)
            .finish()
    }
}
impl WebSocket {
pub async fn connect(url: impl AsRef<str>) -> io::Result<Self> {
WebSocketBuilder::new(url).connect().await
}
pub fn url(&self) -> &str {
&self.info.url
}
pub fn protocol(&self) -> &str {
&self.info.protocol
}
pub fn interface(&self) -> Interface {
self.info.interface
}
pub fn into_split(self) -> (WebSocketSender, WebSocketReceiver) {
let Self { inner, info, read_buf } = self;
match inner {
Inner::Stream(inner) => {
let (sender, receiver) = inner.into_split();
let sender = WebSocketSender { inner: SenderInner::Stream(sender), info: info.clone() };
let receiver = WebSocketReceiver { inner: ReceiverInner::Stream(receiver), info, read_buf };
(sender, receiver)
}
Inner::Standard(inner) => {
let (sender, receiver) = inner.into_split();
let sender = WebSocketSender { inner: SenderInner::Standard(sender), info: info.clone() };
let receiver =
WebSocketReceiver { inner: ReceiverInner::Standard(receiver), info, read_buf: Vec::new() };
(sender, receiver)
}
}
}
pub fn close(self) {
self.into_split().0.close();
}
#[track_caller]
pub fn close_with_reason(self, code: CloseCode, reason: &str) {
self.into_split().0.close_with_reason(code, reason);
}
pub fn closed(&self) -> Closed {
match &self.inner {
Inner::Stream(inner) => inner.closed(),
Inner::Standard(inner) => inner.closed(),
}
}
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
match &mut self.inner {
Inner::Stream(inner) => inner.sender.poll_ready_unpin(cx),
Inner::Standard(inner) => inner.sender.poll_ready_unpin(cx),
}
}
fn start_send(mut self: Pin<&mut Self>, item: &JsValue) -> Result<(), io::Error> {
match &mut self.inner {
Inner::Stream(inner) => inner.sender.start_send_unpin(item),
Inner::Standard(inner) => inner.sender.start_send_unpin(item),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
match &mut self.inner {
Inner::Stream(inner) => inner.sender.poll_flush_unpin(cx),
Inner::Standard(inner) => inner.sender.poll_flush_unpin(cx),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
match &mut self.inner {
Inner::Stream(inner) => inner.sender.poll_close_unpin(cx),
Inner::Standard(inner) => inner.sender.poll_close_unpin(cx),
}
}
}
impl Sink<&str> for WebSocket {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: &str) -> Result<(), Self::Error> {
self.start_send(&JsValue::from_str(item))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl Sink<String> for WebSocket {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
self.start_send(&JsValue::from_str(&item))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl Sink<&[u8]> for WebSocket {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
self.start_send(&Uint8Array::from(item))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl Sink<Vec<u8>> for WebSocket {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
self.start_send(&Uint8Array::from(&item[..]))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl Sink<Msg> for WebSocket {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Msg) -> Result<(), Self::Error> {
match item {
Msg::Text(text) => self.start_send(&JsValue::from_str(&text)),
Msg::Binary(vec) => self.start_send(&Uint8Array::from(&vec[..])),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl AsyncWrite for WebSocket {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
ready!(self.as_mut().poll_ready(cx))?;
self.start_send(&Uint8Array::from(buf))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.poll_close(cx)
}
}
/// Yields received messages, delegating to the active interface's receiver.
impl Stream for WebSocket {
    type Item = io::Result<Msg>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match &mut this.inner {
            Inner::Stream(socket) => socket.receiver.poll_next_unpin(cx),
            Inner::Standard(socket) => socket.receiver.poll_next_unpin(cx),
        }
    }
}
impl AsyncRead for WebSocket {
fn poll_read(
mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut tokio::io::ReadBuf,
) -> Poll<io::Result<()>> {
while self.read_buf.is_empty() {
let Some(msg) = ready!(self.as_mut().poll_next(cx)?) else { return Poll::Ready(Ok(())) };
self.read_buf = msg.to_vec();
}
let part = if buf.remaining() < self.read_buf.len() {
let rem = self.read_buf.split_off(buf.remaining());
mem::replace(&mut self.read_buf, rem)
} else {
mem::take(&mut self.read_buf)
};
buf.put_slice(&part);
Poll::Ready(Ok(()))
}
}
pub struct WebSocketSender {
inner: SenderInner,
info: Rc<Info>,
}
enum SenderInner {
Stream(stream::Sender),
Standard(standard::Sender),
}
impl fmt::Debug for WebSocketSender {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WebSocketSender")
.field("url", &self.info.url)
.field("protocol", &self.protocol())
.field("interface", &self.interface())
.finish()
}
}
impl WebSocketSender {
pub fn url(&self) -> &str {
&self.info.url
}
pub fn protocol(&self) -> &str {
&self.info.protocol
}
pub fn interface(&self) -> Interface {
self.info.interface
}
pub fn close(self) {
self.close_with_reason(CloseCode::NormalClosure, "");
}
#[track_caller]
pub fn close_with_reason(self, code: CloseCode, reason: &str) {
if !code.is_valid() {
panic!("WebSocket close code {code} is invalid");
}
match self.inner {
SenderInner::Stream(sender) => sender.close(code.into(), reason),
SenderInner::Standard(sender) => sender.close(code.into(), reason),
}
}
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
match &mut self.inner {
SenderInner::Stream(inner) => inner.poll_ready_unpin(cx),
SenderInner::Standard(inner) => inner.poll_ready_unpin(cx),
}
}
fn start_send(mut self: Pin<&mut Self>, item: &JsValue) -> Result<(), io::Error> {
match &mut self.inner {
SenderInner::Stream(inner) => inner.start_send_unpin(item),
SenderInner::Standard(inner) => inner.start_send_unpin(item),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
match &mut self.inner {
SenderInner::Stream(inner) => inner.poll_flush_unpin(cx),
SenderInner::Standard(inner) => inner.poll_flush_unpin(cx),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
match &mut self.inner {
SenderInner::Stream(inner) => inner.poll_close_unpin(cx),
SenderInner::Standard(inner) => inner.poll_close_unpin(cx),
}
}
}
impl Sink<&str> for WebSocketSender {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: &str) -> Result<(), Self::Error> {
self.start_send(&JsValue::from_str(item))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl Sink<String> for WebSocketSender {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
self.start_send(&JsValue::from_str(&item))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl Sink<&[u8]> for WebSocketSender {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: &[u8]) -> Result<(), Self::Error> {
self.start_send(&Uint8Array::from(item))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl Sink<Vec<u8>> for WebSocketSender {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
self.start_send(&Uint8Array::from(&item[..]))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl Sink<Msg> for WebSocketSender {
type Error = io::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Msg) -> Result<(), Self::Error> {
match item {
Msg::Text(text) => self.start_send(&JsValue::from_str(&text)),
Msg::Binary(vec) => self.start_send(&Uint8Array::from(&vec[..])),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.poll_close(cx)
}
}
impl AsyncWrite for WebSocketSender {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
ready!(self.as_mut().poll_ready(cx))?;
self.start_send(&Uint8Array::from(buf))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.poll_close(cx)
}
}
pub struct WebSocketReceiver {
inner: ReceiverInner,
info: Rc<Info>,
read_buf: Vec<u8>,
}
enum ReceiverInner {
Stream(stream::Receiver),
Standard(standard::Receiver),
}
impl fmt::Debug for WebSocketReceiver {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("WebSocketReceiver")
.field("url", &self.info.url)
.field("protocol", &self.protocol())
.field("interface", &self.interface())
.finish()
}
}
impl WebSocketReceiver {
    /// Returns the URL recorded for this connection.
    pub fn url(&self) -> &str {
        self.info.url.as_str()
    }

    /// Returns the subprotocol recorded for this connection.
    pub fn protocol(&self) -> &str {
        self.info.protocol.as_str()
    }

    /// Returns the browser interface backing this connection.
    pub fn interface(&self) -> Interface {
        self.info.interface
    }
}
/// Yields received messages, delegating to the active interface's receiver.
impl Stream for WebSocketReceiver {
    type Item = io::Result<Msg>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        match &mut this.inner {
            ReceiverInner::Stream(receiver) => receiver.poll_next_unpin(cx),
            ReceiverInner::Standard(receiver) => receiver.poll_next_unpin(cx),
        }
    }
}
impl AsyncRead for WebSocketReceiver {
fn poll_read(
mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut tokio::io::ReadBuf,
) -> Poll<io::Result<()>> {
while self.read_buf.is_empty() {
let Some(msg) = ready!(self.as_mut().poll_next(cx)?) else { return Poll::Ready(Ok(())) };
self.read_buf = msg.to_vec();
}
let part = if buf.remaining() < self.read_buf.len() {
let rem = self.read_buf.split_off(buf.remaining());
mem::replace(&mut self.read_buf, rem)
} else {
mem::take(&mut self.read_buf)
};
buf.put_slice(&part);
Poll::Ready(Ok(()))
}
}