use std::{
io::ErrorKind,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use js_sys::Uint8Array;
use log::{debug, error, warn};
use wasm_bindgen::{closure::Closure, JsCast};
use wasm_bindgen_futures::spawn_local;
use web_sys::{BinaryType, CloseEvent, ErrorEvent, MessageEvent, WebSocket};
use crate::{ClientSocket, NetcodeTransportError, HTTP_CONNECT_REQ};
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WebSocketClientConfig {
pub server_url: url::Url,
}
impl WebSocketClientConfig {
pub fn server_address(&self) -> Result<SocketAddr, anyhow::Error> {
let host = self
.server_url
.host()
.ok_or_else(|| std::io::Error::other("WebSocketClientConfig url does not have a host"))?;
let port = self.server_url.port().unwrap_or_default();
match host {
url::Host::Domain(_) => {
Err(std::io::Error::other("WebSocketClientConfig url is a domain but a SocketAddr was expected").into())
}
url::Host::Ipv4(ipv4) => Ok((ipv4, port).into()),
url::Host::Ipv6(ipv6) => Ok((ipv6, port).into()),
}
}
}
#[derive(Debug)]
pub struct WebSocketClient {
server_url: url::Url,
server_address: SocketAddr,
server_has_tls: bool,
connect_req_sender: async_channel::Sender<Vec<u8>>,
incoming_receiver: async_channel::Receiver<Vec<u8>>,
close_sender: async_channel::Sender<()>,
outgoing_sender: async_channel::Sender<Vec<u8>>,
closed: Arc<AtomicBool>,
is_disconnected: bool,
sent_connection_request: bool,
}
impl WebSocketClient {
pub fn new(config: WebSocketClientConfig) -> Result<Self, anyhow::Error> {
let server_address = config.server_address().unwrap_or_else(|_| SocketAddr::from(([0, 0, 0, 0], 0)));
let mut server_url = config.server_url.clone();
let server_has_tls = match server_url.scheme() {
"wss" => true,
"ws" => false,
other => {
return Err(std::io::Error::other(format!(
"failed setting up websocket client, server url has '{other}' scheme instead of \
'ws' or 'wss'"
))
.into());
}
};
let (close_sender, close_receiver) = async_channel::unbounded::<()>();
let (incoming_sender, incoming_receiver) = async_channel::unbounded::<Vec<u8>>();
let (connect_req_sender, connect_req_receiver) = async_channel::bounded::<Vec<u8>>(1);
let (outgoing_sender, outgoing_receiver) = async_channel::unbounded::<Vec<u8>>();
let closed = Arc::new(AtomicBool::new(false));
let inner_close_sender = close_sender.clone();
let inner_closed = closed.clone();
spawn_local(async move {
let Ok(connection_req) = connect_req_receiver.recv().await else {
inner_closed.store(true, Ordering::Relaxed);
return;
};
let connect_msg_ser = urlencoding::encode_binary(&connection_req);
server_url.set_query(Some(format!("{}={}", HTTP_CONNECT_REQ, &connect_msg_ser).as_str()));
let ws = match WebSocket::new(server_url.as_str()) {
Ok(ws) => ws,
Err(err) => {
warn!(
"failed initializing websocket client with server url \"{}\": {err:?}",
server_url.as_str()
);
inner_closed.store(true, Ordering::Relaxed);
return;
}
};
ws.set_binary_type(BinaryType::Arraybuffer);
let message_closed = inner_closed.clone();
let message_close_sender = inner_close_sender.clone();
let on_message_callback = Closure::<dyn FnMut(_)>::new(move |e: MessageEvent| {
let msg = Uint8Array::new(&e.data()).to_vec();
if incoming_sender.try_send(msg).is_err() {
message_closed.store(true, Ordering::Relaxed);
let _ = message_close_sender.try_send(());
}
});
let close_closed = inner_closed.clone();
let close_close_sender = inner_close_sender.clone();
let on_close_callback = Closure::<dyn FnMut(_)>::new(move |_: CloseEvent| {
close_closed.store(true, Ordering::Relaxed);
let _ = close_close_sender.try_send(());
});
ws.set_onmessage(Some(on_message_callback.as_ref().unchecked_ref()));
ws.set_onclose(Some(on_close_callback.as_ref().unchecked_ref()));
on_message_callback.forget();
on_close_callback.forget();
let (open_sx, open_rx) = futures_channel::oneshot::channel();
let on_open_callback = {
let mut open_sx = Some(open_sx);
Closure::wrap(Box::new(move |_event| {
open_sx.take().map(|open_sx| open_sx.send(()));
}) as Box<dyn FnMut(web_sys::Event)>)
};
let (err_sx, err_rx) = futures_channel::oneshot::channel();
let on_error_callback = {
let mut err_sx = Some(err_sx);
Closure::wrap(Box::new(move |_error_event| {
err_sx.take().map(|err_sx| err_sx.send(()));
}) as Box<dyn FnMut(ErrorEvent)>)
};
ws.set_onerror(Some(on_error_callback.as_ref().unchecked_ref()));
ws.set_onopen(Some(on_open_callback.as_ref().unchecked_ref()));
on_error_callback.forget();
on_open_callback.forget();
let result = futures_util::future::select(open_rx, err_rx).await;
ws.set_onopen(None);
ws.set_onerror(None);
let ws = match result {
futures_util::future::Either::Left((_, _)) => ws,
futures_util::future::Either::Right((_, _)) => {
inner_closed.store(true, Ordering::Relaxed);
let _ = inner_close_sender.try_send(());
return;
}
};
let err_closed = inner_closed.clone();
let err_close_sender = inner_close_sender.clone();
let on_error_callback = Closure::<dyn FnMut(_)>::new(move |e: ErrorEvent| {
warn!("WebSocket connection error {}", e.message());
err_closed.store(true, Ordering::Relaxed);
let _ = err_close_sender.try_send(());
});
ws.set_onerror(Some(on_error_callback.as_ref().unchecked_ref()));
on_error_callback.forget();
let ws_clone = ws.clone();
let send_closed = inner_closed.clone();
spawn_local(async move {
while let Ok(msg) = outgoing_receiver.recv().await {
if ws_clone.ready_state() != 1 {
warn!("Tried to send packet through closed websocket connection");
break;
}
if ws_clone.send_with_u8_array(&msg).is_err() {
let _ = ws_clone.close();
break;
}
}
send_closed.store(true, Ordering::Relaxed);
let _ = inner_close_sender.try_send(());
});
spawn_local(async move {
let _ = close_receiver.recv().await;
let _ = ws.close();
inner_closed.store(true, Ordering::Relaxed);
});
});
Ok(Self {
server_url: config.server_url,
server_address,
server_has_tls,
connect_req_sender,
incoming_receiver,
close_sender,
outgoing_sender,
closed,
is_disconnected: false,
sent_connection_request: false,
})
}
pub fn is_disconnected(&self) -> bool {
self.is_disconnected || self.closed.load(Ordering::Relaxed)
}
pub fn server_url(&self) -> &url::Url {
&self.server_url
}
pub fn server_address(&self) -> SocketAddr {
self.server_address
}
pub fn disconnect(&mut self) {
let _ = self.close_sender.send(());
self.is_disconnected = true;
self.closed.store(true, Ordering::Relaxed);
}
}
impl Drop for WebSocketClient {
fn drop(&mut self) {
self.disconnect();
}
}
impl ClientSocket for WebSocketClient {
fn is_encrypted(&self) -> bool {
self.server_has_tls
}
fn is_reliable(&self) -> bool {
true
}
fn addr(&self) -> std::io::Result<SocketAddr> {
Err(std::io::Error::from(ErrorKind::AddrNotAvailable))
}
fn is_closed(&mut self) -> bool {
self.is_disconnected()
}
fn close(&mut self) {
self.disconnect()
}
fn preupdate(&mut self) {
if !self.is_disconnected && self.closed.load(Ordering::Relaxed) {
self.disconnect();
}
}
fn try_recv(&mut self, buffer: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
if self.is_closed() {
return Err(std::io::Error::from(ErrorKind::ConnectionAborted));
}
let Ok(packet) = self.incoming_receiver.try_recv() else {
return Err(std::io::Error::from(ErrorKind::WouldBlock));
};
if packet.len() > buffer.len() {
return Err(std::io::Error::from(ErrorKind::InvalidData));
}
buffer[..packet.len()].copy_from_slice(&packet[..]);
Ok((packet.len(), self.server_address()))
}
fn postupdate(&mut self) {}
fn send(&mut self, addr: SocketAddr, packet: &[u8]) -> Result<(), NetcodeTransportError> {
if self.is_closed() {
return Err(std::io::Error::from(ErrorKind::ConnectionAborted).into());
}
if addr != self.server_address() {
error!("tried sending packet to invalid WebSocket server {}", addr);
self.close();
return Err(std::io::Error::from(ErrorKind::AddrNotAvailable).into());
}
if !self.sent_connection_request {
let packet_type = renetcode2::Packet::packet_type_from_buffer(packet)?;
if packet_type != renetcode2::PacketType::ConnectionRequest {
debug!(
"ignoring {:?}, the first packet sent to a webSocket client must be a connection request",
packet_type
);
return Ok(());
}
let mut data = Vec::default();
data.extend_from_slice(packet);
let _ = self.connect_req_sender.try_send(data);
self.sent_connection_request = true;
return Ok(());
}
if let Err(_) = self.outgoing_sender.try_send(packet.into()) {
self.close();
return Err(std::io::Error::from(ErrorKind::ConnectionAborted).into());
}
Ok(())
}
}