use crate::{Error, Method, Request, Result};
use futures_channel::mpsc::UnboundedReceiver;
use futures_util::Stream;
use js_sys::Uint8Array;
use serde::Serialize;
use url::Url;
use worker_sys::ext::WebSocketExt;
#[cfg(not(feature = "http"))]
use crate::Fetch;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use wasm_bindgen::convert::FromWasmAbi;
use wasm_bindgen::prelude::Closure;
use wasm_bindgen::JsCast;
#[cfg(feature = "http")]
use wasm_bindgen_futures::JsFuture;
pub use crate::ws_events::*;
/// Two WebSocket ends created together by the runtime's `WebSocketPair` constructor.
///
/// Both halves are exposed as public fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketPair {
    /// The `client` half reported by the underlying pair.
    pub client: WebSocket,
    /// The `server` half reported by the underlying pair.
    pub server: WebSocket,
}

// SAFETY: NOTE(review): `WebSocket` wraps a JS handle (`web_sys::WebSocket`), which is not
// `Send`/`Sync` by itself. These impls presumably rely on the Workers wasm runtime running on
// a single thread, so the handles never actually cross threads — confirm before relying on it.
unsafe impl Send for WebSocketPair {}
unsafe impl Sync for WebSocketPair {}

impl WebSocketPair {
    /// Creates a new pair via `worker_sys::WebSocketPair::new`.
    ///
    /// # Errors
    /// Returns an error if constructing the pair or reading either of its halves fails.
    pub fn new() -> Result<Self> {
        let mut raw_pair = worker_sys::WebSocketPair::new()?;
        // Read the client half first, then the server half, matching the field order.
        let client: WebSocket = raw_pair.client()?.into();
        let server: WebSocket = raw_pair.server()?.into();
        Ok(WebSocketPair { client, server })
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocket {
socket: web_sys::WebSocket,
}
unsafe impl Send for WebSocket {}
unsafe impl Sync for WebSocket {}
impl WebSocket {
pub async fn connect(url: Url) -> Result<WebSocket> {
WebSocket::connect_with_protocols(url, None).await
}
pub async fn connect_with_protocols(
mut url: Url,
protocols: Option<Vec<&str>>,
) -> Result<WebSocket> {
let scheme: String = match url.scheme() {
"ws" => "http".into(),
"wss" => "https".into(),
scheme => scheme.into(),
};
url.set_scheme(&scheme).unwrap();
let mut req = Request::new(url.as_str(), Method::Get)?;
req.headers_mut()?.set("upgrade", "websocket")?;
match protocols {
None => {}
Some(v) => {
req.headers_mut()?
.set("Sec-WebSocket-Protocol", v.join(",").as_str())?;
}
}
#[cfg(not(feature = "http"))]
let res = Fetch::Request(req).send().await?;
#[cfg(feature = "http")]
let res: crate::Response = fetch_with_request_raw(req).await?.into();
match res.websocket() {
Some(ws) => Ok(ws),
None => Err(Error::RustError("server did not accept".into())),
}
}
pub fn accept(&self) -> Result<()> {
self.socket.accept().map_err(Error::from)
}
pub fn send<T: Serialize>(&self, data: &T) -> Result<()> {
let value = serde_json::to_string(data)?;
self.send_with_str(value.as_str())
}
pub fn send_with_str<S: AsRef<str>>(&self, data: S) -> Result<()> {
self.socket
.send_with_str(data.as_ref())
.map_err(Error::from)
}
pub fn send_with_bytes<D: AsRef<[u8]>>(&self, bytes: D) -> Result<()> {
let uint8_array = Uint8Array::from(bytes.as_ref());
self.socket.send_with_array_buffer(&uint8_array.buffer())?;
Ok(())
}
pub fn close<S: AsRef<str>>(&self, code: Option<u16>, reason: Option<S>) -> Result<()> {
if let Some((code, reason)) = code.zip(reason) {
self.socket
.close_with_code_and_reason(code, reason.as_ref())
} else if let Some(code) = code {
self.socket.close_with_code(code)
} else {
self.socket.close()
}
.map_err(Error::from)
}
fn add_event_handler<T: FromWasmAbi + 'static, F: FnMut(T) + 'static>(
&self,
r#type: &str,
fun: F,
) -> Result<Closure<dyn FnMut(T)>> {
let js_callback = Closure::wrap(Box::new(fun) as Box<dyn FnMut(T)>);
self.socket
.add_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref())
.map_err(Error::from)?;
Ok(js_callback)
}
fn remove_event_handler<T: FromWasmAbi + 'static>(
&self,
r#type: &str,
js_callback: Closure<dyn FnMut(T)>,
) -> Result<()> {
self.socket
.remove_event_listener_with_callback(r#type, js_callback.as_ref().unchecked_ref())
.map_err(Error::from)
}
pub fn events(&self) -> Result<EventStream> {
let (tx, rx) = futures_channel::mpsc::unbounded::<Result<WebsocketEvent>>();
let tx = Rc::new(tx);
let close_closure = self.add_event_handler("close", {
let tx = tx.clone();
move |event: web_sys::CloseEvent| {
tx.unbounded_send(Ok(WebsocketEvent::Close(event.into())))
.unwrap();
}
})?;
let message_closure = self.add_event_handler("message", {
let tx = tx.clone();
move |event: web_sys::MessageEvent| {
tx.unbounded_send(Ok(WebsocketEvent::Message(event.into())))
.unwrap();
}
})?;
let error_closure =
self.add_event_handler("error", move |event: web_sys::ErrorEvent| {
let error = event.error();
tx.unbounded_send(Err(error.into())).unwrap();
})?;
Ok(EventStream {
ws: self,
rx,
closed: false,
closures: Some((message_closure, error_closure, close_closure)),
})
}
pub fn serialize_attachment<T: Serialize>(&self, value: T) -> Result<()> {
self.socket
.serialize_attachment(serde_wasm_bindgen::to_value(&value)?)
.map_err(Error::from)
}
pub fn deserialize_attachment<T: serde::de::DeserializeOwned>(&self) -> Result<Option<T>> {
let value = self.socket.deserialize_attachment().map_err(Error::from)?;
if value.is_null() || value.is_undefined() {
return Ok(None);
}
serde_wasm_bindgen::from_value::<T>(value)
.map(Some)
.map_err(Error::from)
}
}
/// A JS event-listener closure that receives an event of type `T`.
type EvCallback<T> = Closure<dyn FnMut(T)>;

/// Stream of events for a [`WebSocket`], as returned by `WebSocket::events`.
///
/// Reaching end-of-stream:
/// - After a close event has been yielded, every later poll returns `None`.
///
/// Cleanup:
/// - Dropping the stream removes its three event listeners from the socket.
#[pin_project::pin_project(PinnedDrop)]
pub struct EventStream<'ws> {
    /// Socket the listeners are attached to; used again on drop to detach them.
    ws: &'ws WebSocket,
    /// Receiving end of the channel the listeners push events into.
    #[pin]
    rx: UnboundedReceiver<Result<WebsocketEvent>>,
    /// Set once a close event has been yielded.
    closed: bool,
    /// Registered listener closures, kept alive here. Taken out, leaving `None`, on drop.
    closures: Option<(
        EvCallback<web_sys::MessageEvent>,
        EvCallback<web_sys::ErrorEvent>,
        EvCallback<web_sys::CloseEvent>,
    )>,
}

impl<'ws> Stream for EventStream<'ws> {
    type Item = Result<WebsocketEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        if *projected.closed {
            return Poll::Ready(None);
        }
        match projected.rx.poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(next) => {
                // A close event is the last item this stream hands out.
                if matches!(next, Some(Ok(WebsocketEvent::Close(_)))) {
                    *projected.closed = true;
                }
                Poll::Ready(next)
            }
        }
    }
}

#[pin_project::pinned_drop]
impl PinnedDrop for EventStream<'_> {
    /// Detaches the message, error and close listeners, in that order.
    ///
    /// # Panics
    /// Panics in either of these cases:
    /// - the closures were already taken;
    /// - removing any listener fails.
    fn drop(self: Pin<&'_ mut Self>) {
        let projected = self.project();
        let (on_message, on_error, on_close) = projected
            .closures
            .take()
            .expect("double drop on worker::EventStream");
        let socket = *projected.ws;

        socket
            .remove_event_handler("message", on_message)
            .expect("could not remove message handler");
        socket
            .remove_event_handler("error", on_error)
            .expect("could not remove error handler");
        socket
            .remove_event_handler("close", on_close)
            .expect("could not remove close handler");
    }
}
impl From<web_sys::WebSocket> for WebSocket {
fn from(socket: web_sys::WebSocket) -> Self {
Self { socket }
}
}
impl AsRef<web_sys::WebSocket> for WebSocket {
fn as_ref(&self) -> &web_sys::WebSocket {
&self.socket
}
}
/// Event types yielded by a WebSocket's event stream, plus accessors for their payloads.
pub mod ws_events {
    use serde::de::DeserializeOwned;
    use wasm_bindgen::JsValue;

    use crate::Error;

    /// An event delivered on a WebSocket event stream.
    #[derive(Debug, Clone)]
    pub enum WebsocketEvent {
        /// A `message` event.
        Message(MessageEvent),
        /// A `close` event.
        Close(CloseEvent),
    }

    /// Wrapper around a `web_sys::MessageEvent`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct MessageEvent {
        event: web_sys::MessageEvent,
    }

    impl From<web_sys::MessageEvent> for MessageEvent {
        fn from(raw: web_sys::MessageEvent) -> Self {
            MessageEvent { event: raw }
        }
    }

    impl AsRef<web_sys::MessageEvent> for MessageEvent {
        fn as_ref(&self) -> &web_sys::MessageEvent {
            &self.event
        }
    }

    impl MessageEvent {
        /// The event's raw `data` payload.
        fn data(&self) -> JsValue {
            self.event.data()
        }

        /// Returns the payload as a `String` if it is a JS string, otherwise `None`.
        pub fn text(&self) -> Option<String> {
            self.data().as_string()
        }

        /// Returns a copy of the payload's bytes if the payload is a JS object.
        ///
        /// Returns `None` for non-object payloads such as strings. The object is viewed
        /// through a `Uint8Array` and then copied into a `Vec`.
        ///
        /// NOTE(review): this suits an `ArrayBuffer` payload. Other object types (for example
        /// a `Blob`) are not array-like and would produce an empty vector. Confirm which
        /// payload types the runtime delivers.
        pub fn bytes(&self) -> Option<Vec<u8>> {
            let payload = self.data();
            payload
                .is_object()
                .then(|| js_sys::Uint8Array::new(&payload).to_vec())
        }

        /// Parses the text payload as JSON into `T`.
        ///
        /// # Errors
        /// Returns an error if the payload is not text, or if it is not valid JSON for `T`.
        pub fn json<T: DeserializeOwned>(&self) -> crate::Result<T> {
            let text = self
                .text()
                .ok_or_else(|| Error::from("data of message event is not text"))?;
            serde_json::from_str(&text).map_err(Error::from)
        }
    }

    /// Wrapper around a `web_sys::CloseEvent`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct CloseEvent {
        event: web_sys::CloseEvent,
    }

    impl From<web_sys::CloseEvent> for CloseEvent {
        fn from(raw: web_sys::CloseEvent) -> Self {
            CloseEvent { event: raw }
        }
    }

    impl AsRef<web_sys::CloseEvent> for CloseEvent {
        fn as_ref(&self) -> &web_sys::CloseEvent {
            &self.event
        }
    }

    impl CloseEvent {
        /// The close reason carried by the event.
        pub fn reason(&self) -> String {
            self.event.reason()
        }

        /// The close status code carried by the event.
        pub fn code(&self) -> u16 {
            self.event.code()
        }

        /// The event's `wasClean` flag.
        pub fn was_clean(&self) -> bool {
            self.event.was_clean()
        }
    }
}
/// Sends `request` with the worker global scope's `fetch` and returns the raw
/// `web_sys::Response`.
///
/// # Errors
/// Returns an error if the fetch promise rejects, or if the resolved value is not a
/// `web_sys::Response`.
#[cfg(feature = "http")]
async fn fetch_with_request_raw(request: crate::Request) -> Result<web_sys::Response> {
    let inner_request = request.inner();
    // Build the future inside its own scope so the global-scope JS handle (not `Send`) is
    // dropped before the `.await` below.
    let pending = {
        let global_scope: web_sys::WorkerGlobalScope = js_sys::global().unchecked_into();
        let promise = global_scope.fetch_with_request(inner_request);
        crate::send::SendFuture::new(JsFuture::from(promise))
    };
    let resolved = pending.await?;
    let response = resolved.dyn_into::<web_sys::Response>()?;
    Ok(response)
}