use std::{
io::ErrorKind,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use fragile::Fragile;
use js_sys::{Promise, Uint8Array};
use log::{debug, error, warn};
use renetcode2::NETCODE_MAX_PACKET_BYTES;
use send_wrapper::SendWrapper;
use wasm_bindgen::{prelude::Closure, JsValue};
use wasm_bindgen_futures::{spawn_local, JsFuture};
use web_sys::{ReadableStreamDefaultReader, WritableStreamDefaultWriter};
use crate::{ClientSocket, NetcodeTransportError, ServerCertHash, WebServerDestination, HTTP_CONNECT_REQ};
use super::bindings::{
ReadableStreamDefaultReadResult, WebTransport, WebTransportCongestionControl, WebTransportError, WebTransportHash, WebTransportOptions,
};
/// Congestion-control preference requested when opening the WebTransport session.
///
/// Each variant maps onto the `WebTransportCongestionControl` value of the same name.
/// The default is [`CongestionControl::LowLatency`].
// NOTE: variant order is kept stable because positional serde formats encode the index.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum CongestionControl {
    /// Maps to `WebTransportCongestionControl::Default`.
    Default,
    /// Maps to `WebTransportCongestionControl::Throughput`.
    Throughput,
    /// Maps to `WebTransportCongestionControl::LowLatency` (the default).
    #[default]
    LowLatency,
}
impl CongestionControl {
fn to_wt(&self) -> WebTransportCongestionControl {
match self {
Self::Default => WebTransportCongestionControl::Default,
Self::Throughput => WebTransportCongestionControl::Throughput,
Self::LowLatency => WebTransportCongestionControl::LowLatency,
}
}
}
/// Configuration used by `WebTransportClient::new` to open a WebTransport session.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WebTransportClientConfig {
    /// Server to connect to; converted into the session URL and into the client's server address.
    pub server_dest: WebServerDestination,
    /// Congestion-control preference forwarded to the browser through `WebTransportOptions`.
    pub congestion_control: CongestionControl,
    /// Server certificate hashes, each passed as a `sha-256` `WebTransportHash`.
    /// Only forwarded when the list is non-empty.
    pub server_cert_hashes: Vec<ServerCertHash>,
}
impl WebTransportClientConfig {
    /// Creates a config for `server_dest` with the default congestion control and no
    /// server certificate hashes.
    pub fn new(server_dest: impl Into<WebServerDestination>) -> Self {
        Self::new_with_certs(server_dest, Vec::new())
    }

    /// Creates a config for `server_dest` that supplies `server_cert_hashes` as the accepted
    /// server certificate hashes, using the default congestion control.
    pub fn new_with_certs(server_dest: impl Into<WebServerDestination>, server_cert_hashes: Vec<ServerCertHash>) -> Self {
        let server_dest = server_dest.into();
        let congestion_control = CongestionControl::default();
        Self {
            server_dest,
            congestion_control,
            server_cert_hashes,
        }
    }
}
impl WebTransportClientConfig {
pub fn wt_options(&self) -> WebTransportOptions {
let mut options = WebTransportOptions::new();
options.congestion_control(self.congestion_control.to_wt()).require_unreliable(true);
if self.server_cert_hashes.len() > 0 {
let cert_hashes = self
.server_cert_hashes
.iter()
.map(|cert| {
let mut hash = WebTransportHash::new();
hash.algorithm("sha-256");
hash.value(&js_sys::Uint8Array::from(cert.hash.as_ref()));
wasm_bindgen::JsValue::from(hash)
})
.collect::<js_sys::Array>();
options.server_certificate_hashes(&cert_hashes);
}
options
}
}
/// Browser (WASM) WebTransport implementation of [`ClientSocket`].
///
/// Netcode packets are exchanged as session datagrams. The session itself is opened by a
/// background task spawned in `new`, once the first connection request is sent.
pub struct WebTransportClient {
    /// URL converted from the configured server destination (without the connect-request query).
    server_url: url::Url,
    /// Reported as the source of every received packet; also the only accepted send target.
    server_address: SocketAddr,
    /// Hands the first `ConnectionRequest` packet to the background task, which then opens the session.
    connect_req_sender: async_channel::Sender<Vec<u8>>,
    /// Datagrams forwarded by the reader task.
    incoming_receiver: async_channel::Receiver<Vec<u8>>,
    /// Signals the background task to close the session.
    close_sender: async_channel::Sender<()>,
    /// Delivers the datagram writer once the background task has obtained it.
    writer_receiver: async_channel::Receiver<Fragile<WritableStreamDefaultWriter>>,
    /// Datagram writer; `None` until collected from `writer_receiver`, and again after `disconnect`.
    writer: Option<Fragile<WritableStreamDefaultWriter>>,
    /// Shared with the spawned tasks.
    /// Set by `disconnect` and by the background task when it stops.
    closed: Arc<AtomicBool>,
    /// Set once `disconnect` has run on this handle.
    is_disconnected: bool,
    /// Whether the connection request has already been handed to `connect_req_sender`.
    sent_connection_request: bool,
}
impl WebTransportClient {
pub fn new(config: WebTransportClientConfig) -> Self {
let options = config.wt_options();
let (close_sender, close_receiver) = async_channel::unbounded::<()>();
let (incoming_sender, incoming_receiver) = async_channel::unbounded::<Vec<u8>>();
let (connect_req_sender, connect_req_receiver) = async_channel::bounded::<Vec<u8>>(1);
let (writer_sender, writer_receiver) = async_channel::bounded::<Fragile<WritableStreamDefaultWriter>>(1);
let closed = Arc::new(AtomicBool::new(false));
let inner_server_dest = config.server_dest.clone();
let inner_close_sender = close_sender.clone();
let inner_closed = closed.clone();
spawn_local(async move {
let Ok(connection_req) = connect_req_receiver.recv().await else {
inner_closed.store(true, Ordering::Relaxed);
return;
};
let mut url: url::Url = inner_server_dest
.clone()
.try_into()
.expect("could not convert server destination to url");
let connect_msg_ser = urlencoding::encode_binary(&connection_req);
url.set_query(Some(format!("{}={}", HTTP_CONNECT_REQ, &connect_msg_ser).as_str()));
let web_transport = match Self::init_web_transport(url.as_str(), options).await {
Ok(web_transport) => web_transport,
Err(err) => {
let _ = inner_close_sender.send(()).await;
warn!("failed setting up web transport client {:?}", err);
return;
}
};
let close_callback_handle: Closure<dyn FnMut(JsValue)> = Self::get_close_callback(inner_close_sender);
let _ = web_transport.closed().then(&close_callback_handle).catch(&close_callback_handle);
let writer: WritableStreamDefaultWriter = match web_transport.datagrams().writable().get_writer() {
Ok(writer) => writer,
Err(err) => {
web_transport.close();
warn!("failed setting up web transport client {:?}", err);
return;
}
};
if !inner_closed.load(Ordering::Relaxed) {
let writer = Fragile::new(writer);
let _ = writer_sender.try_send(writer);
} else {
handle_promise(writer.close());
web_transport.close();
return;
}
let reader = web_transport.datagrams().readable().get_reader();
let reader: ReadableStreamDefaultReader = JsValue::from(reader).into();
let reader_closed = inner_closed.clone();
Self::reader_task(reader, reader_closed, incoming_sender);
let _ = close_receiver.recv().await;
inner_closed.store(true, Ordering::Relaxed);
web_transport.close();
});
Self {
server_url: config
.server_dest
.clone()
.try_into()
.expect("could not convert server destination to url"),
server_address: config.server_dest.into(),
connect_req_sender,
incoming_receiver,
close_sender,
writer_receiver,
writer: None,
closed,
is_disconnected: false,
sent_connection_request: false,
}
}
pub fn is_disconnected(&self) -> bool {
self.is_disconnected || self.closed.load(Ordering::Relaxed)
}
pub fn server_url(&self) -> &url::Url {
&self.server_url
}
pub fn server_address(&self) -> SocketAddr {
self.server_address
}
pub fn disconnect(&mut self) {
let _ = self.close_sender.send(());
if let Ok(writer) = self.writer_receiver.try_recv() {
self.writer = Some(writer);
}
if let Some(writer) = self.writer.as_ref().map(Fragile::get) {
handle_promise(writer.close());
}
self.writer = None;
self.is_disconnected = true;
self.closed.store(true, Ordering::Relaxed);
}
async fn init_web_transport(url: &str, options: WebTransportOptions) -> Result<WebTransport, WebTransportError> {
let web_transport = WebTransport::new_with_options(url, &options)?;
JsFuture::from(web_transport.ready()).await?;
Ok(web_transport)
}
fn get_close_callback(sender: async_channel::Sender<()>) -> Closure<dyn FnMut(JsValue)> {
Closure::new(move |_| {
let _ = sender.try_send(());
})
}
fn reader_task(reader: ReadableStreamDefaultReader, reader_closed: Arc<AtomicBool>, incoming_sender: async_channel::Sender<Vec<u8>>) {
spawn_local(async move {
loop {
if reader_closed.load(Ordering::Relaxed) {
break;
}
let Ok(incoming) = JsFuture::from(reader.read()).await else { break };
let result: ReadableStreamDefaultReadResult = incoming.into();
if result.is_done() {
break;
}
let data: Uint8Array = result.value().into();
if data.length() as usize > NETCODE_MAX_PACKET_BYTES {
error!("received packet that is too large from the webtransport server {}", data.length());
break;
}
let Ok(()) = incoming_sender.try_send(data.to_vec()) else { break };
}
handle_promise(reader.cancel());
});
}
}
impl Drop for WebTransportClient {
fn drop(&mut self) {
self.disconnect();
}
}
/// Attaches a no-op rejection handler to `promise`, so that a rejection is not reported as unhandled.
///
/// The handler closure is created on first use and cached in a `static` for the rest of the program.
fn handle_promise(promise: Promise) {
    // `OnceLock` replaces the previous `static mut` + `unsafe` check-then-set initialisation,
    // which was unsynchronised and relied on `static_mut_refs` being allowed.
    // `SendWrapper` makes the JS closure storable in a `static`; dereferencing it from a thread
    // other than the creating one panics, exactly as before.
    static NOTHING_CALLBACK: std::sync::OnceLock<SendWrapper<Closure<dyn FnMut(JsValue)>>> = std::sync::OnceLock::new();

    let nothing_callback = NOTHING_CALLBACK.get_or_init(|| SendWrapper::new(Closure::new(|_| {})));
    let _ = promise.catch(nothing_callback);
}
impl std::fmt::Debug for WebTransportClient {
fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
}
impl ClientSocket for WebTransportClient {
    /// Always `true` for this transport.
    fn is_encrypted(&self) -> bool {
        true
    }

    /// Always `false`: packets travel as session datagrams.
    fn is_reliable(&self) -> bool {
        false
    }

    /// No local socket address is exposed; always fails with `AddrNotAvailable`.
    fn addr(&self) -> std::io::Result<SocketAddr> {
        Err(ErrorKind::AddrNotAvailable.into())
    }

    fn is_closed(&mut self) -> bool {
        self.is_disconnected()
    }

    fn close(&mut self) {
        self.disconnect()
    }

    /// Completes a disconnect that the background tasks flagged, then collects the datagram
    /// writer if it has become available.
    fn preupdate(&mut self) {
        let flagged_closed = self.closed.load(Ordering::Relaxed);
        if flagged_closed && !self.is_disconnected {
            self.disconnect();
        }
        if self.writer.is_none() {
            self.writer = self.writer_receiver.try_recv().ok();
        }
    }

    /// Copies the next received datagram into `buffer`.
    ///
    /// The returned error kind depends on the state:
    /// - `ConnectionAborted` when closed;
    /// - `WouldBlock` when nothing is queued;
    /// - `InvalidData` when the datagram does not fit (the datagram is consumed).
    fn try_recv(&mut self, buffer: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
        if self.is_closed() {
            return Err(ErrorKind::ConnectionAborted.into());
        }
        let packet = self
            .incoming_receiver
            .try_recv()
            .map_err(|_| std::io::Error::from(ErrorKind::WouldBlock))?;
        let len = packet.len();
        let dest = buffer
            .get_mut(..len)
            .ok_or_else(|| std::io::Error::from(ErrorKind::InvalidData))?;
        dest.copy_from_slice(&packet);
        Ok((len, self.server_address()))
    }

    fn postupdate(&mut self) {}

    /// Sends `packet` to the server.
    ///
    /// The first accepted packet must be a `ConnectionRequest`; it is handed to the background
    /// task, which opens the session with it, and any other packet sent before it is dropped.
    /// After that, packets are written as datagrams once the writer is available and are
    /// dropped silently until then. Sending to any address other than the server closes the
    /// client.
    fn send(&mut self, addr: SocketAddr, packet: &[u8]) -> Result<(), NetcodeTransportError> {
        if self.is_closed() {
            return Err(std::io::Error::from(ErrorKind::ConnectionAborted).into());
        }
        if addr != self.server_address() {
            error!("tried sending packet to invalid WebTransport server {}", addr);
            self.close();
            return Err(std::io::Error::from(ErrorKind::AddrNotAvailable).into());
        }

        if !self.sent_connection_request {
            let packet_type = renetcode2::Packet::packet_type_from_buffer(packet)?;
            if packet_type == renetcode2::PacketType::ConnectionRequest {
                let _ = self.connect_req_sender.try_send(packet.to_vec());
                self.sent_connection_request = true;
            } else {
                debug!(
                    "ignoring {:?}, the first packet sent to a webtransport client must be a connection request",
                    packet_type
                );
            }
            return Ok(());
        }

        if let Some(writer) = self.writer.as_ref() {
            let chunk = Uint8Array::from(packet);
            handle_promise(writer.get().write_with_chunk(&chunk.into()));
        }
        Ok(())
    }
}