use crate::runtime::codec::PROTOCOL_VERSION;
use crate::runtime::state::{ConnectionState, StateMachine};
use crate::runtime::transport::{Transport, TransportConfig, TransportError};
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
/// WebTransport client that exchanges datagrams through a background task.
///
/// `connect` spawns a task that owns the connection; payloads move between
/// the caller and that task over two bounded `mpsc` channels.
pub struct WebTransportClient {
/// Settings supplied at construction; the URL, connect timeout and maximum
/// message size are read from here.
config: TransportConfig,
/// Connection state machine, shared with the background task so the task
/// can record a disconnect when its loop ends.
state: Arc<RwLock<StateMachine>>,
/// Sending half of the outbound queue; set by `connect`, cleared by
/// `disconnect`. `send` reports `NotConnected` while this is `None`.
outgoing_tx: Option<mpsc::Sender<Vec<u8>>>,
/// Receiving half of the inbound queue; set by `connect`, cleared by
/// `disconnect`. `receive`/`try_receive` report `NotConnected` while `None`.
incoming_rx: Option<mpsc::Receiver<Vec<u8>>>,
/// Test-only: when set, `connect` trusts this certificate hash instead of
/// the native certificate store.
#[cfg(test)]
test_server_hash: Option<wtransport::tls::Sha256Digest>,
/// Handle of the spawned datagram task; `disconnect` aborts it.
task_handle: Option<tokio::task::JoinHandle<()>>,
}
impl WebTransportClient {
    /// Capacity of each bounded queue between the client and its datagram task.
    const CHANNEL_CAPACITY: usize = 256;

    /// Creates a client for `config` without touching the network.
    ///
    /// The state machine is built from `config.retry`; no channels or
    /// background task exist until [`connect`](Self::connect) succeeds.
    pub fn new(config: TransportConfig) -> Self {
        Self {
            config: config.clone(),
            state: Arc::new(RwLock::new(StateMachine::new(config.retry))),
            outgoing_tx: None,
            incoming_rx: None,
            #[cfg(test)]
            test_server_hash: None,
            task_handle: None,
        }
    }

    /// Test-only constructor: like [`new`](Self::new), but `connect` will trust
    /// the server certificate identified by `test_server_hash`.
    #[cfg(test)]
    fn new_with_test_hash(
        config: TransportConfig,
        test_server_hash: wtransport::tls::Sha256Digest,
    ) -> Self {
        let mut client = Self::new(config);
        client.test_server_hash = Some(test_server_hash);
        client
    }

    /// Returns the state currently reported by the state machine.
    pub async fn state(&self) -> ConnectionState {
        self.state.read().await.state()
    }

    /// Returns whether the current state counts as connected.
    pub async fn is_connected(&self) -> bool {
        self.state().await.is_connected()
    }

    /// Opens the connection and spawns the background datagram task.
    ///
    /// # Errors
    ///
    /// * [`TransportError::ConnectionFailed`] if the configured URL does not
    ///   parse, its scheme is not `https`, the endpoint cannot be created, or
    ///   the connection attempt fails.
    /// * [`TransportError::InvalidState`] if the state machine rejects the
    ///   transition to connecting or to connected.
    /// * [`TransportError::Timeout`] if the attempt exceeds
    ///   `connect_timeout_ms`.
    ///
    /// When the attempt fails after the state machine has entered the
    /// connecting state, the state is reset via `disconnect()` so the client
    /// is not left stuck mid-connect.
    pub async fn connect(&mut self) -> Result<(), TransportError> {
        // Validate the URL before touching the state machine, so a bad
        // configuration never changes the connection state.
        let url: url::Url = self
            .config
            .url
            .parse()
            .map_err(|e| TransportError::ConnectionFailed(format!("invalid URL: {}", e)))?;
        let scheme = url.scheme();
        if scheme != "https" {
            return Err(TransportError::ConnectionFailed(format!(
                "unsupported URL scheme '{}': WebTransport requires 'https'",
                scheme
            )));
        }

        {
            let mut state = self.state.write().await;
            state
                .start_connecting()
                .map_err(|_| TransportError::InvalidState)?;
        }

        let client_config = self.client_config();
        let connection = match Self::open_connection(client_config, &self.config).await {
            Ok(connection) => connection,
            Err(err) => {
                // Undo `start_connecting`; otherwise the state machine would
                // keep reporting an attempt that is no longer in flight.
                // NOTE(review): `disconnect()` is the only reset visible in
                // this file — confirm no dedicated failure transition exists.
                self.state.write().await.disconnect();
                return Err(err);
            }
        };

        {
            let mut state = self.state.write().await;
            state
                .connected()
                .map_err(|_| TransportError::InvalidState)?;
        }

        // Only publish the channel halves once the connection is fully
        // established, so a failed connect leaves `send`/`receive` reporting
        // `NotConnected` rather than operating on dead channels.
        let (outgoing_tx, outgoing_rx) = mpsc::channel::<Vec<u8>>(Self::CHANNEL_CAPACITY);
        let (incoming_tx, incoming_rx) = mpsc::channel::<Vec<u8>>(Self::CHANNEL_CAPACITY);
        let state = Arc::clone(&self.state);
        let handle = tokio::spawn(async move {
            Self::datagram_loop(connection, outgoing_rx, incoming_tx, state).await;
        });

        self.outgoing_tx = Some(outgoing_tx);
        self.incoming_rx = Some(incoming_rx);
        self.task_handle = Some(handle);
        Ok(())
    }

    /// Builds the TLS/bind configuration used by [`connect`](Self::connect).
    ///
    /// Test builds bind to the default address and trust either the injected
    /// certificate hash or the native certificate store; other builds use the
    /// library default configuration.
    fn client_config(&self) -> wtransport::ClientConfig {
        #[cfg(test)]
        let client_config = {
            let builder = wtransport::ClientConfig::builder().with_bind_default();
            if let Some(hash) = &self.test_server_hash {
                builder
                    .with_server_certificate_hashes([hash.clone()])
                    .build()
            } else {
                builder.with_native_certs().build()
            }
        };
        #[cfg(not(test))]
        let client_config = wtransport::ClientConfig::default();
        client_config
    }

    /// Creates the client endpoint and connects to `config.url`, bounded by
    /// `config.connect_timeout_ms`.
    async fn open_connection(
        client_config: wtransport::ClientConfig,
        config: &TransportConfig,
    ) -> Result<wtransport::Connection, TransportError> {
        let endpoint = wtransport::Endpoint::client(client_config)
            .map_err(|e| TransportError::ConnectionFailed(format!("endpoint error: {}", e)))?;
        let timeout = tokio::time::Duration::from_millis(config.connect_timeout_ms);
        tokio::time::timeout(timeout, endpoint.connect(&config.url))
            .await
            .map_err(|_| TransportError::Timeout)?
            .map_err(|e| TransportError::ConnectionFailed(format!("connect error: {}", e)))
    }

    /// Pumps datagrams between the connection and the two channels.
    ///
    /// The task ends when the outbound queue is closed (the client
    /// disconnected or was dropped), when the connection can no longer send
    /// or receive, or when the inbound queue has no receiver. In the latter
    /// two cases the shared state machine is marked disconnected.
    async fn datagram_loop(
        connection: wtransport::Connection,
        mut outgoing_rx: mpsc::Receiver<Vec<u8>>,
        incoming_tx: mpsc::Sender<Vec<u8>>,
        state: Arc<RwLock<StateMachine>>,
    ) {
        loop {
            tokio::select! {
                outbound = outgoing_rx.recv() => {
                    let Some(data) = outbound else {
                        // Every sender is gone: the owning client disconnected
                        // or was dropped. It is responsible for the state, so
                        // just stop and let the connection close.
                        return;
                    };
                    if let Err(err) = connection.send_datagram(data) {
                        // An oversized payload only affects that datagram;
                        // drop it (datagrams are unreliable anyway) instead of
                        // tearing down a healthy connection.
                        if !matches!(err, wtransport::error::SendDatagramError::TooLarge) {
                            break;
                        }
                    }
                }
                inbound = connection.receive_datagram() => {
                    let Ok(datagram) = inbound else {
                        break;
                    };
                    if incoming_tx.send(datagram.payload().to_vec()).await.is_err() {
                        break;
                    }
                }
            }
        }
        state.write().await.disconnect();
    }

    /// Marks the client disconnected, stops the background task and drops
    /// both channel halves. Safe to call when not connected.
    pub async fn disconnect(&mut self) {
        self.state.write().await.disconnect();
        // Cancel the task first so it is already stopping when the channel
        // halves below are closed.
        if let Some(handle) = self.task_handle.take() {
            handle.abort();
        }
        self.outgoing_tx = None;
        self.incoming_rx = None;
    }

    /// Queues `data` for transmission as a datagram.
    ///
    /// # Errors
    ///
    /// * [`TransportError::InvalidPacket`] if `data` is empty or does not
    ///   start with `PROTOCOL_VERSION`.
    /// * [`TransportError::PacketTooLarge`] if `data` is longer than
    ///   `max_message_size`.
    /// * [`TransportError::NotConnected`] if no outbound queue exists.
    /// * [`TransportError::SendFailed`] if the background task is gone.
    pub async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
        if data.first() != Some(&PROTOCOL_VERSION) {
            return Err(TransportError::InvalidPacket);
        }
        let max = self.config.max_message_size;
        if data.len() > max {
            return Err(TransportError::PacketTooLarge {
                size: data.len(),
                max,
            });
        }
        let tx = self
            .outgoing_tx
            .as_ref()
            .ok_or(TransportError::NotConnected)?;
        tx.send(data).await.map_err(|_| TransportError::SendFailed)
    }

    /// Waits for the next inbound datagram payload.
    ///
    /// # Errors
    ///
    /// [`TransportError::NotConnected`] if no inbound queue exists, or
    /// [`TransportError::ConnectionClosed`] once the queue is closed and empty.
    pub async fn receive(&mut self) -> Result<Vec<u8>, TransportError> {
        let rx = self
            .incoming_rx
            .as_mut()
            .ok_or(TransportError::NotConnected)?;
        rx.recv().await.ok_or(TransportError::ConnectionClosed)
    }

    /// Returns the next inbound payload if one is already queued, or
    /// `Ok(None)` when the queue is currently empty.
    ///
    /// # Errors
    ///
    /// [`TransportError::NotConnected`] if no inbound queue exists, or
    /// [`TransportError::ConnectionClosed`] once the queue is closed and empty.
    pub fn try_receive(&mut self) -> Result<Option<Vec<u8>>, TransportError> {
        let rx = self
            .incoming_rx
            .as_mut()
            .ok_or(TransportError::NotConnected)?;
        match rx.try_recv() {
            Ok(data) => Ok(Some(data)),
            Err(mpsc::error::TryRecvError::Empty) => Ok(None),
            Err(mpsc::error::TryRecvError::Disconnected) => Err(TransportError::ConnectionClosed),
        }
    }

    /// Returns the URL this client was configured with.
    pub fn url(&self) -> &str {
        &self.config.url
    }
}
impl Transport for WebTransportClient {
async fn connect(&mut self) -> Result<(), TransportError> {
self.connect().await
}
async fn disconnect(&mut self) {
self.disconnect().await
}
async fn send(&self, data: Vec<u8>) -> Result<(), TransportError> {
self.send(data).await
}
async fn receive(&mut self) -> Result<Vec<u8>, TransportError> {
self.receive().await
}
fn try_receive(&mut self) -> Result<Option<Vec<u8>>, TransportError> {
self.try_receive()
}
async fn state(&self) -> ConnectionState {
self.state().await
}
async fn is_connected(&self) -> bool {
self.is_connected().await
}
fn url(&self) -> &str {
self.url()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::state::RetryConfig;
    use pretty_assertions::assert_eq;
    use std::time::Duration;

    /// URL used by the tests that never open a real connection.
    const LOCAL_URL: &str = "https://localhost:4433";

    /// Builds a transport configuration pointing at `url` with test defaults.
    fn test_config(url: &str) -> TransportConfig {
        TransportConfig {
            url: String::from(url),
            retry: RetryConfig::default(),
            connect_timeout_ms: 5_000,
            keepalive_interval_ms: 0,
            max_message_size: 65535,
        }
    }

    /// Builds a fresh, never-connected client for `url`.
    fn client_for(url: &str) -> WebTransportClient {
        WebTransportClient::new(test_config(url))
    }

    // A new client starts out disconnected.
    #[tokio::test]
    async fn test_client_initial_state() {
        let fresh = client_for(LOCAL_URL);
        assert_eq!(fresh.state().await, ConnectionState::Disconnected);
    }

    // Non-https schemes are rejected with a message naming the required one.
    #[tokio::test]
    async fn test_connect_wrong_scheme() {
        let mut ws_client = client_for("ws://example.com");
        let outcome = ws_client.connect().await;
        assert!(matches!(outcome, Err(TransportError::ConnectionFailed(_))));
        if let Err(TransportError::ConnectionFailed(msg)) = outcome {
            assert!(msg.contains("https"), "error should mention https: {}", msg);
        }
    }

    // A string that is not a URL at all fails to connect.
    #[tokio::test]
    async fn test_connect_invalid_url() {
        let mut bogus = client_for("not-a-url");
        assert!(bogus.connect().await.is_err());
    }

    // Empty payloads and payloads with the wrong first byte are rejected.
    #[tokio::test]
    async fn test_send_requires_version_byte() {
        let idle = client_for(LOCAL_URL);
        for rejected in [vec![], vec![0xFF, 0x01]] {
            let outcome = idle.send(rejected).await;
            assert!(matches!(outcome, Err(TransportError::InvalidPacket)));
        }
    }

    // A well-formed packet cannot be sent before connecting.
    #[tokio::test]
    async fn test_send_not_connected() {
        let idle = client_for(LOCAL_URL);
        let outcome = idle.send(vec![PROTOCOL_VERSION, 0x01]).await;
        assert!(matches!(outcome, Err(TransportError::NotConnected)));
    }

    // Awaiting a packet before connecting reports NotConnected.
    #[tokio::test]
    async fn test_receive_not_connected() {
        let mut idle = client_for(LOCAL_URL);
        let outcome = idle.receive().await;
        assert!(matches!(outcome, Err(TransportError::NotConnected)));
    }

    // Polling for a packet before connecting reports NotConnected.
    #[tokio::test]
    async fn test_try_receive_not_connected() {
        let mut idle = client_for(LOCAL_URL);
        let outcome = idle.try_receive();
        assert!(matches!(outcome, Err(TransportError::NotConnected)));
    }

    // Disconnecting repeatedly keeps the client in the disconnected state.
    #[tokio::test]
    async fn test_disconnect_is_idempotent() {
        let mut idle = client_for(LOCAL_URL);
        for _ in 0..2 {
            idle.disconnect().await;
            assert_eq!(idle.state().await, ConnectionState::Disconnected);
        }
    }

    // The accessor returns the configured URL unchanged.
    #[tokio::test]
    async fn test_url_accessor() {
        let idle = client_for(LOCAL_URL);
        assert_eq!(idle.url(), "https://localhost:4433");
    }

    // An 11-byte packet exceeds a 10-byte limit.
    #[tokio::test]
    async fn test_packet_too_large() {
        let tight_config = TransportConfig {
            max_message_size: 10,
            ..test_config(LOCAL_URL)
        };
        let limited = WebTransportClient::new(tight_config);
        let oversized: Vec<u8> = std::iter::once(PROTOCOL_VERSION)
            .chain([0u8; 10])
            .collect();
        let outcome = limited.send(oversized).await;
        assert!(matches!(outcome, Err(TransportError::PacketTooLarge { .. })));
    }

    // Round-trips one datagram through an in-process echo server.
    #[tokio::test]
    async fn test_connect_to_inprocess_server() {
        let identity = wtransport::Identity::self_signed(["localhost", "127.0.0.1", "::1"])
            .expect("failed to build self-signed identity for test server");
        let cert_hash = identity.certificate_chain().as_slice()[0].hash();

        let server = wtransport::Endpoint::server(
            wtransport::ServerConfig::builder()
                .with_bind_default(0)
                .with_identity(identity)
                .build(),
        )
        .expect("failed to start in-process WebTransport server");
        let server_port = server
            .local_addr()
            .expect("failed to read server local addr")
            .port();

        // Accept one session, echo its first datagram, then linger briefly so
        // the echo can be delivered before the connection is dropped.
        let echo_task = tokio::spawn(async move {
            let Ok(request) = server.accept().await.await else {
                return;
            };
            let Ok(connection) = request.accept().await else {
                return;
            };
            let Ok(datagram) = connection.receive_datagram().await else {
                return;
            };
            let _ = connection.send_datagram(datagram.payload());
            tokio::time::sleep(Duration::from_millis(200)).await;
        });

        let server_url = format!("https://localhost:{}", server_port);
        let mut client = WebTransportClient::new_with_test_hash(test_config(&server_url), cert_hash);
        client
            .connect()
            .await
            .expect("client should connect to in-process server");
        assert!(client.is_connected().await);

        let packet = vec![PROTOCOL_VERSION, 0xAA, 0xBB, 0xCC];
        client
            .send(packet.clone())
            .await
            .expect("send should succeed");

        let echoed = tokio::time::timeout(Duration::from_secs(5), client.receive())
            .await
            .expect("timed out waiting for echoed datagram")
            .expect("receive should succeed");
        assert_eq!(echoed, packet);

        client.disconnect().await;
        assert_eq!(client.state().await, ConnectionState::Disconnected);
        echo_task.abort();
    }
}