use crate::{
DEFAULT_MAX_MESSAGE_SIZE,
error::{CallError, DisconnectionError, RecvError, RunError, SendError},
handshake::{WebSocketHandshake, WebSocketHandshakeError},
timeout::Timeout,
websocket::{ListenerTask, SenderTask, WebSocket},
};
use async_tungstenite::tokio::{ConnectStream, connect_async_with_config};
use core::time::Duration;
use future_form::{FutureForm, Sendable};
use futures::{FutureExt, future::BoxFuture};
use subduction_core::{
connection::{
Connection, Reconnect,
authenticated::Authenticated,
handshake::{self, Audience, AuthenticateError},
message::{BatchSyncRequest, BatchSyncResponse, Message, RequestId},
},
peer::id::PeerId,
timestamp::TimestampSeconds,
};
use subduction_crypto::{nonce::Nonce, signer::Signer};
use tungstenite::{http::Uri, protocol::WebSocketConfig};
#[derive(Debug, thiserror::Error)]
pub enum ClientConnectError {
#[error("WebSocket error: {0}")]
WebSocket(#[from] tungstenite::Error),
#[error("handshake error: {0}")]
Handshake(#[from] AuthenticateError<WebSocketHandshakeError>),
}
/// WebSocket client for the Tokio runtime implementing subduction's
/// [`Connection`] and [`Reconnect`] traits.
///
/// Alongside the live socket, the client retains the inputs needed to redo
/// the connection and handshake (`address`, `signer`, `audience`); the
/// timeout strategy and default time limit are read back from the socket.
#[derive(Clone, Debug)]
pub struct TokioWebSocketClient<R, O>
where
    R: Signer<Sendable> + Clone,
    O: Timeout<Sendable> + Send + Sync,
{
    /// URI of the server this client connects to.
    address: Uri,
    /// Signer handed to the handshake to authenticate this client.
    signer: R,
    /// Audience handed to the handshake.
    audience: Audience,
    /// The WebSocket connection established by the handshake.
    socket: WebSocket<ConnectStream, Sendable, O>,
}
impl<R: Signer<Sendable> + Clone + Send + Sync, O: Timeout<Sendable> + Send + Sync>
    TokioWebSocketClient<R, O>
{
    /// Connects to the WebSocket server at `address` and runs the client side
    /// of the authentication handshake.
    ///
    /// The WebSocket transport is configured with a maximum message size of
    /// [`DEFAULT_MAX_MESSAGE_SIZE`]. `timeout` and `default_time_limit` are
    /// handed to the underlying [`WebSocket`]; `signer` and `audience` are used
    /// for the handshake and kept on the client for later reconnects.
    ///
    /// On success returns:
    ///
    /// - the authenticated client,
    /// - a [`ListenerTask`] wrapping the socket's `listen` future, and
    /// - a [`SenderTask`] wrapping the socket's sender future.
    ///
    /// The caller is expected to drive both tasks (for example with
    /// `tokio::spawn`, as the `Reconnect` implementation does).
    ///
    /// # Errors
    ///
    /// Returns [`ClientConnectError::WebSocket`] if the WebSocket connection
    /// cannot be opened, or [`ClientConnectError::Handshake`] if the
    /// authentication handshake fails.
    pub async fn new<'a>(
        address: Uri,
        timeout: O,
        default_time_limit: Duration,
        signer: R,
        audience: Audience,
    ) -> Result<
        (
            Authenticated<Self, Sendable>,
            ListenerTask<'a>,
            SenderTask<'a>,
        ),
        ClientConnectError,
    >
    where
        O: 'a,
        R: 'a,
    {
        tracing::info!("Connecting to WebSocket server at {address}");

        let mut ws_config = WebSocketConfig::default();
        ws_config.max_message_size = Some(DEFAULT_MAX_MESSAGE_SIZE);
        // `address` is cloned because the original is stored on the client below.
        let (ws_stream, _resp) =
            connect_async_with_config(address.clone(), Some(ws_config)).await?;

        let now = TimestampSeconds::now();
        let nonce = Nonce::random();

        // `timeout` is not used after the handshake, so it is moved directly
        // into the socket constructor instead of being cloned.
        let (authenticated, sender_fut) = handshake::initiate::<Sendable, _, _, _, _>(
            WebSocketHandshake::new(ws_stream),
            |ws_handshake, peer_id| {
                let (socket, sender_fut) = WebSocket::<_, _, O>::new(
                    ws_handshake.into_inner(),
                    timeout,
                    default_time_limit,
                    peer_id,
                );
                (socket, Sendable::from_future(sender_fut))
            },
            &signer,
            audience,
            now,
            nonce,
        )
        .await?;

        let server_id = authenticated.peer_id();
        tracing::info!("Handshake complete: connected to {server_id}");

        // One clone of the socket is stored on the client; a second one is
        // moved into the listener future.
        let socket = authenticated.inner().clone();
        let listener_socket = socket.clone();
        let listener = ListenerTask::new(async move { listener_socket.listen().await }.boxed());
        let sender = SenderTask::new(sender_fut);

        let authenticated_client = authenticated.map(|_socket| Self {
            address,
            signer,
            audience,
            socket,
        });
        Ok((authenticated_client, listener, sender))
    }

    /// Drives the underlying socket's `listen` future to completion.
    ///
    /// # Errors
    ///
    /// Propagates any [`RunError`] returned by the socket.
    pub async fn listen(&self) -> Result<(), RunError> {
        self.socket.listen().await
    }
}
impl<R, O> Connection<Sendable> for TokioWebSocketClient<R, O>
where
    R: Signer<Sendable> + Clone + Send + Sync,
    O: Timeout<Sendable> + Send + Sync,
{
    type SendError = SendError;
    type RecvError = RecvError;
    type CallError = CallError;
    type DisconnectionError = DisconnectionError;

    /// Returns the peer ID reported by the underlying socket.
    fn peer_id(&self) -> PeerId {
        let socket = &self.socket;
        Connection::<Sendable>::peer_id(socket)
    }

    /// Lazily obtains the next request ID from the underlying socket; nothing
    /// happens until the returned future is polled.
    fn next_request_id(&self) -> BoxFuture<'_, RequestId> {
        Box::pin(async move { Connection::<Sendable>::next_request_id(&self.socket).await })
    }

    /// Resolves immediately to `Ok(())`; the underlying socket is not touched.
    ///
    /// NOTE(review): this does not close `self.socket` — confirm whether it
    /// should delegate to the socket's own `disconnect`.
    fn disconnect(&self) -> BoxFuture<'_, Result<(), Self::DisconnectionError>> {
        Box::pin(async { Ok::<(), Self::DisconnectionError>(()) })
    }

    /// Logs `message` at debug level right away, then returns the socket's
    /// `send` future unchanged.
    fn send(&self, message: &Message) -> BoxFuture<'_, Result<(), Self::SendError>> {
        tracing::debug!("client sending message: {:?}", message);
        let socket = &self.socket;
        Connection::<Sendable>::send(socket, message)
    }

    /// When polled, logs at debug level and then awaits the socket's `recv`.
    fn recv(&self) -> BoxFuture<'_, Result<Message, Self::RecvError>> {
        Box::pin(async move {
            tracing::debug!("client waiting to receive message");
            Connection::<Sendable>::recv(&self.socket).await
        })
    }

    /// When polled, logs the request at debug level and forwards it, together
    /// with `override_timeout`, to the socket's `call`.
    fn call(
        &self,
        req: BatchSyncRequest,
        override_timeout: Option<Duration>,
    ) -> BoxFuture<'_, Result<BatchSyncResponse, Self::CallError>> {
        Box::pin(async move {
            tracing::debug!("client making call with request: {:?}", req);
            Connection::<Sendable>::call(&self.socket, req, override_timeout).await
        })
    }
}
impl<R, O> Reconnect<Sendable> for TokioWebSocketClient<R, O>
where
    R: 'static + Signer<Sendable> + Clone + Send + Sync,
    O: 'static + Timeout<Sendable> + Send + Sync,
{
    type ReconnectionError = ClientConnectError;

    /// Establishes a fresh authenticated connection and replaces `self` with it.
    ///
    /// The stored address, signer and audience are reused, together with the
    /// current socket's timeout strategy and default time limit. The new
    /// connection's listener and sender tasks are spawned with `tokio::spawn`,
    /// and any error they finish with is logged at `info` level. The previous
    /// client value is dropped when `self` is overwritten.
    ///
    /// NOTE(review): nothing here explicitly stops the previous connection's
    /// listener/sender tasks — confirm they terminate on their own.
    fn reconnect(&mut self) -> BoxFuture<'_, Result<(), Self::ReconnectionError>> {
        async move {
            let address = self.address.clone();
            let timeout = self.socket.timeout_strategy().clone();
            let time_limit = self.socket.default_time_limit();
            let signer = self.signer.clone();

            let (fresh, listener_task, sender_task) =
                Self::new(address, timeout, time_limit, signer, self.audience).await?;
            *self = fresh.into_inner();

            tokio::spawn(async move {
                if let Err(e) = listener_task.await {
                    tracing::info!("WebSocket client listener disconnected after reconnect: {e:?}");
                }
            });
            tokio::spawn(async move {
                if let Err(e) = sender_task.await {
                    tracing::info!("WebSocket client sender disconnected after reconnect: {e:?}");
                }
            });

            Ok(())
        }
        .boxed()
    }

    /// Retries on transport-level failures: WebSocket errors, a closed
    /// connection, or a WebSocket error during the handshake. Protocol-level
    /// failures such as decode errors, rejections, and unexpected messages are
    /// not retried.
    fn should_retry(&self, error: &Self::ReconnectionError) -> bool {
        // Deliberately no wildcard arm: a new error variant must be classified
        // explicitly.
        match error {
            ClientConnectError::WebSocket(_)
            | ClientConnectError::Handshake(
                AuthenticateError::ConnectionClosed
                | AuthenticateError::Transport(
                    WebSocketHandshakeError::WebSocket(_)
                    | WebSocketHandshakeError::ConnectionClosed,
                ),
            ) => true,
            ClientConnectError::Handshake(
                AuthenticateError::Transport(WebSocketHandshakeError::UnexpectedMessageType(_))
                | AuthenticateError::Decode(_)
                | AuthenticateError::Handshake(_)
                | AuthenticateError::Rejected { .. }
                | AuthenticateError::UnexpectedMessage,
            ) => false,
        }
    }
}
impl<R, O> PartialEq for TokioWebSocketClient<R, O>
where
    R: Signer<Sendable> + Clone + Send + Sync,
    O: Timeout<Sendable> + Send + Sync,
{
    /// Clients compare equal when they target the same server address and
    /// their sockets report the same peer ID.
    ///
    /// The signer, audience and timeout strategy do not take part in the
    /// comparison. Peer IDs are only queried when the addresses match.
    fn eq(&self, other: &Self) -> bool {
        if self.address != other.address {
            return false;
        }
        self.socket.peer_id() == other.socket.peer_id()
    }
}