use alloc::vec::Vec;
use crate::{
DEFAULT_MAX_MESSAGE_SIZE,
error::{DisconnectionError, RecvError, RunError, SendError},
handshake::{WebSocketHandshake, WebSocketHandshakeError},
websocket::{ListenerTask, SenderTask, WebSocket},
};
use async_tungstenite::tokio::{ConnectStream, connect_async_with_config};
use future_form::{FutureForm, Sendable};
use futures::{FutureExt, future::BoxFuture};
use subduction_core::{
authenticated::Authenticated,
connection::{Connection, Reconnect, message::SyncMessage},
handshake::{self, AuthenticateError, audience::Audience},
timestamp::TimestampSeconds,
transport::Transport,
};
use subduction_crypto::{nonce::Nonce, signer::Signer};
use tungstenite::{http::Uri, protocol::WebSocketConfig};
#[derive(Debug, thiserror::Error)]
pub enum ClientConnectError {
#[error("WebSocket error: {0}")]
WebSocket(#[from] tungstenite::Error),
#[error("handshake error: {0}")]
Handshake(#[from] AuthenticateError<WebSocketHandshakeError>),
}
/// A Tokio-based WebSocket client connection to a Subduction server.
///
/// Besides the live socket, the client keeps everything needed to redo the
/// connection and handshake (address, signer, audience) so it can reconnect.
#[derive(Debug, Clone)]
pub struct TokioWebSocketClient<R>
where
    R: Signer<Sendable> + Clone,
{
    /// Server URI this client connects (and reconnects) to.
    address: Uri,
    /// Signer used to authenticate this side of the handshake.
    signer: R,
    /// Audience value passed to the handshake.
    audience: Audience,
    /// The authenticated WebSocket used for all traffic.
    socket: WebSocket<ConnectStream, Sendable>,
}
impl<R: Signer<Sendable> + Clone + Send + Sync> TokioWebSocketClient<R> {
    /// Connects to the WebSocket server at `address` and performs the Subduction handshake.
    ///
    /// On success returns the authenticated client together with its listener and
    /// sender tasks. The caller is responsible for driving (e.g. spawning) both tasks;
    /// the connection does not make progress otherwise.
    ///
    /// # Errors
    ///
    /// * [`ClientConnectError::WebSocket`] if the WebSocket connection cannot be opened.
    /// * [`ClientConnectError::Handshake`] if the handshake over the opened socket fails.
    pub async fn new<'a>(
        address: Uri,
        signer: R,
        audience: Audience,
    ) -> Result<
        (
            Authenticated<Self, Sendable>,
            ListenerTask<'a>,
            SenderTask<'a>,
        ),
        ClientConnectError,
    >
    where
        R: 'a,
    {
        tracing::info!("Connecting to WebSocket server at {address}");

        // Bound inbound message size (tungstenite `max_message_size`) to the crate default.
        let mut ws_config = WebSocketConfig::default();
        ws_config.max_message_size = Some(DEFAULT_MAX_MESSAGE_SIZE);

        let (ws_stream, _response) =
            connect_async_with_config(address.clone(), Some(ws_config)).await?;

        let now = TimestampSeconds::now();
        let nonce = Nonce::random();
        let (authenticated, sender_fut) = handshake::initiate::<Sendable, _, _, _, _>(
            WebSocketHandshake::new(ws_stream),
            |ws_handshake, peer_id| {
                let (socket, sender_fut) = WebSocket::new(ws_handshake.into_inner(), peer_id);
                (socket, Sendable::from_future(sender_fut))
            },
            &signer,
            audience,
            now,
            nonce,
        )
        .await?;

        let server_id = authenticated.peer_id();
        tracing::info!("Handshake complete: connected to {server_id}");

        // The listener task needs its own handle on the socket; the authenticated
        // socket itself is moved into the client below, so only one clone is needed.
        let listener_socket = authenticated.inner().clone();
        let listener = ListenerTask::new(async move { listener_socket.listen().await }.boxed());
        let sender = SenderTask::new(sender_fut);

        let authenticated_client = authenticated.map(|socket| TokioWebSocketClient {
            address,
            signer,
            audience,
            socket,
        });

        Ok((authenticated_client, listener, sender))
    }

    /// Runs the underlying socket's listen loop until it returns.
    ///
    /// # Errors
    ///
    /// Propagates the [`RunError`] returned by the socket's listen loop.
    pub async fn listen(&self) -> Result<(), RunError> {
        self.socket.listen().await
    }
}
impl<R: Signer<Sendable> + Clone + Send + Sync> Transport<Sendable> for TokioWebSocketClient<R> {
    type SendError = SendError;
    type RecvError = RecvError;
    type DisconnectionError = DisconnectionError;

    /// Disconnects by delegating to the wrapped socket's [`Transport`] implementation.
    ///
    /// This mirrors `send_bytes`/`recv_bytes`, which also forward to the socket. Previously
    /// this returned `Ok(())` without touching the socket, which left a live connection open.
    fn disconnect(&self) -> BoxFuture<'_, Result<(), Self::DisconnectionError>> {
        Transport::<Sendable>::disconnect(&self.socket)
    }

    /// Sends a raw byte frame over the wrapped socket.
    fn send_bytes(&self, bytes: &[u8]) -> BoxFuture<'_, Result<(), Self::SendError>> {
        tracing::debug!("client sending {} bytes", bytes.len());
        Transport::<Sendable>::send_bytes(&self.socket, bytes)
    }

    /// Receives the next raw byte frame from the wrapped socket.
    fn recv_bytes(&self) -> BoxFuture<'_, Result<Vec<u8>, Self::RecvError>> {
        let socket = self.socket.clone();
        async move {
            // Logged when the future is first polled, not when it is created.
            tracing::debug!("client waiting to receive bytes");
            Transport::<Sendable>::recv_bytes(&socket).await
        }
        .boxed()
    }
}
impl<R: Signer<Sendable> + Clone + Send + Sync> Connection<Sendable, SyncMessage>
for TokioWebSocketClient<R>
{
type SendError = SendError;
type RecvError = RecvError;
type DisconnectionError = DisconnectionError;
fn disconnect(&self) -> BoxFuture<'_, Result<(), Self::DisconnectionError>> {
Transport::<Sendable>::disconnect(self)
}
fn send(&self, message: &SyncMessage) -> BoxFuture<'_, Result<(), Self::SendError>> {
let bytes = message.encode();
let this = self.socket.clone();
async move { Transport::<Sendable>::send_bytes(&this, &bytes).await }.boxed()
}
fn recv(&self) -> BoxFuture<'_, Result<SyncMessage, Self::RecvError>> {
let socket = self.socket.clone();
async move {
loop {
let bytes = Transport::<Sendable>::recv_bytes(&socket).await?;
match SyncMessage::try_decode(&bytes) {
Ok(msg) => return Ok(msg),
Err(e) => {
tracing::warn!("failed to decode inbound bytes as SyncMessage: {e}");
}
}
}
}
.boxed()
}
}
impl<R: 'static + Signer<Sendable> + Clone + Send + Sync> Reconnect<Sendable, SyncMessage>
    for TokioWebSocketClient<R>
{
    type ReconnectionError = ClientConnectError;

    /// Re-establishes the connection using the stored address, signer, and audience.
    ///
    /// On success, `self` is replaced by the newly authenticated client. Its listener and
    /// sender tasks are spawned onto the Tokio runtime, and errors from those tasks are
    /// only logged. On failure, the error is returned before `self` is modified.
    fn reconnect(&mut self) -> BoxFuture<'_, Result<(), Self::ReconnectionError>> {
        async move {
            let address = self.address.clone();
            let signer = self.signer.clone();
            let (fresh, listener_task, sender_task) =
                Self::new(address, signer, self.audience).await?;
            *self = fresh.into_inner();

            tokio::spawn(async move {
                if let Some(e) = listener_task.await.err() {
                    tracing::info!("WebSocket client listener disconnected after reconnect: {e:?}");
                }
            });
            tokio::spawn(async move {
                if let Some(e) = sender_task.await.err() {
                    tracing::info!("WebSocket client sender disconnected after reconnect: {e:?}");
                }
            });

            Ok(())
        }
        .boxed()
    }

    /// Decides whether a failed reconnect attempt is worth retrying.
    ///
    /// Socket-level failures and closed connections are retried. Protocol-level failures
    /// (decode errors, unexpected messages, handshake errors, rejection) are not, since a
    /// retry would presumably fail the same way. The match is intentionally exhaustive,
    /// so a new error variant forces an explicit decision here.
    fn should_retry(&self, error: &Self::ReconnectionError) -> bool {
        match error {
            ClientConnectError::WebSocket(_)
            | ClientConnectError::Handshake(
                AuthenticateError::ConnectionClosed
                | AuthenticateError::Transport(
                    WebSocketHandshakeError::WebSocket(_)
                    | WebSocketHandshakeError::ConnectionClosed,
                ),
            ) => true,
            ClientConnectError::Handshake(
                AuthenticateError::Transport(WebSocketHandshakeError::UnexpectedMessageType(_))
                | AuthenticateError::Decode(_)
                | AuthenticateError::Handshake(_)
                | AuthenticateError::Rejected { .. }
                | AuthenticateError::UnexpectedMessage,
            ) => false,
        }
    }
}
impl<R: Signer<Sendable> + Clone + Send + Sync> PartialEq for TokioWebSocketClient<R> {
    /// Two clients are equal when they target the same address and their sockets
    /// report the same peer id. The signer and audience are not compared.
    fn eq(&self, other: &Self) -> bool {
        if self.address != other.address {
            return false;
        }
        self.socket.peer_id() == other.socket.peer_id()
    }
}