// Crate: subduction_websocket 0.8.0 — WebSocket transport layer for the Subduction sync protocol.
//! # Subduction [`WebSocket`] client for Tokio

use alloc::vec::Vec;

use crate::{
    DEFAULT_MAX_MESSAGE_SIZE,
    error::{DisconnectionError, RecvError, RunError, SendError},
    handshake::{WebSocketHandshake, WebSocketHandshakeError},
    websocket::{ListenerTask, SenderTask, WebSocket},
};
use async_tungstenite::tokio::{ConnectStream, connect_async_with_config};
use future_form::{FutureForm, Sendable};
use futures::{FutureExt, future::BoxFuture};

use subduction_core::{
    authenticated::Authenticated,
    connection::{Connection, Reconnect, message::SyncMessage},
    handshake::{self, AuthenticateError, audience::Audience},
    timestamp::TimestampSeconds,
    transport::Transport,
};
use subduction_crypto::{nonce::Nonce, signer::Signer};
use tungstenite::{http::Uri, protocol::WebSocketConfig};

/// Error type for client connection.
#[derive(Debug, thiserror::Error)]
pub enum ClientConnectError {
    /// WebSocket connection error.
    #[error("WebSocket error: {0}")]
    WebSocket(#[from] tungstenite::Error),

    /// Handshake failed.
    #[error("handshake error: {0}")]
    Handshake(#[from] AuthenticateError<WebSocketHandshakeError>),
}

/// Tokio-backed [`WebSocket`] client for the Subduction sync protocol.
///
/// Besides the live socket, the client remembers the parameters it was dialed with
/// (`address`, `signer`, `audience`), which `reconnect` reuses to establish a fresh connection.
#[derive(Debug, Clone)]
pub struct TokioWebSocketClient<R>
where
    R: Signer<Sendable> + Clone,
{
    /// URI of the server this client connected to.
    address: Uri,
    /// Signer used to authenticate this client during the handshake.
    signer: R,
    /// Expected server identity, as passed to the handshake.
    audience: Audience,
    /// The authenticated socket used for all sends and receives.
    socket: WebSocket<ConnectStream, Sendable>,
}

impl<R: Signer<Sendable> + Clone + Send + Sync> TokioWebSocketClient<R> {
    /// Create a new [`TokioWebSocketClient`] connection.
    ///
    /// Opens a WebSocket connection to `address` (with `max_message_size` set to
    /// [`DEFAULT_MAX_MESSAGE_SIZE`]), then performs the handshake protocol to authenticate
    /// both sides.
    ///
    /// # Arguments
    ///
    /// * `address` - The WebSocket URI to connect to
    /// * `signer` - The client's signer for authentication
    /// * `audience` - The expected server identity ([`Audience::Known`] for a specific peer,
    ///   [`Audience::Discover`] for service discovery)
    ///
    /// # Returns
    ///
    /// A tuple of:
    /// - The authenticated client instance (wrapped in [`Authenticated`] to prove handshake completed)
    /// - A future for the listener task (receives incoming messages)
    /// - A future for the sender task (sends outgoing messages)
    ///
    /// Both futures should be spawned as background tasks.
    ///
    /// # Errors
    ///
    /// Returns [`ClientConnectError::WebSocket`] if the connection could not be established,
    /// or [`ClientConnectError::Handshake`] if the handshake fails.
    pub async fn new<'a>(
        address: Uri,
        signer: R,
        audience: Audience,
    ) -> Result<
        (
            Authenticated<Self, Sendable>,
            ListenerTask<'a>,
            SenderTask<'a>,
        ),
        ClientConnectError,
    >
    where
        R: 'a,
    {
        tracing::info!("Connecting to WebSocket server at {address}");
        let mut ws_config = WebSocketConfig::default();
        ws_config.max_message_size = Some(DEFAULT_MAX_MESSAGE_SIZE);
        let (ws_stream, _resp) =
            connect_async_with_config(address.clone(), Some(ws_config)).await?;

        // Each handshake attempt is given the current timestamp and a fresh random nonce.
        let now = TimestampSeconds::now();
        let nonce = Nonce::random();

        // The closure receives the handshake transport and the peer's id and turns them into
        // the `WebSocket` plus its sender future.
        let (authenticated, sender_fut) = handshake::initiate::<Sendable, _, _, _, _>(
            WebSocketHandshake::new(ws_stream),
            |ws_handshake, peer_id| {
                let (socket, sender_fut) = WebSocket::new(ws_handshake.into_inner(), peer_id);
                (socket, Sendable::from_future(sender_fut))
            },
            &signer,
            audience,
            now,
            nonce,
        )
        .await?;

        let server_id = authenticated.peer_id();
        tracing::info!("Handshake complete: connected to {server_id}");

        // The client and the listener task each hold their own handle to the same socket.
        let socket = authenticated.inner().clone();
        let listener_socket = socket.clone();
        let listener = ListenerTask::new(async move { listener_socket.listen().await }.boxed());
        let sender = SenderTask::new(sender_fut);

        // Lift the Authenticated proof from WebSocket to TokioWebSocketClient
        let authenticated_client = authenticated.map(|_socket| TokioWebSocketClient {
            address,
            signer,
            audience,
            socket,
        });

        Ok((authenticated_client, listener, sender))
    }

    /// Start listening for incoming messages on this client's socket.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * the connection drops unexpectedly
    /// * a message could not be sent or received
    pub async fn listen(&self) -> Result<(), RunError> {
        self.socket.listen().await
    }
}

/// Byte-level [`Transport`] for the client.
///
/// Every operation is forwarded to the underlying [`WebSocket`]'s own [`Transport`] impl.
impl<R: Signer<Sendable> + Clone + Send + Sync> Transport<Sendable> for TokioWebSocketClient<R> {
    type SendError = SendError;
    type RecvError = RecvError;
    type DisconnectionError = DisconnectionError;

    /// Disconnect by delegating to the underlying socket, as `send_bytes` and `recv_bytes` do.
    fn disconnect(&self) -> BoxFuture<'_, Result<(), Self::DisconnectionError>> {
        // Delegate instead of unconditionally reporting success, so a disconnect request
        // (including one made through the `Connection` impl) actually reaches the socket.
        Transport::<Sendable>::disconnect(&self.socket)
    }

    /// Send raw bytes over the underlying socket.
    fn send_bytes(&self, bytes: &[u8]) -> BoxFuture<'_, Result<(), Self::SendError>> {
        tracing::debug!("client sending {} bytes", bytes.len());
        Transport::<Sendable>::send_bytes(&self.socket, bytes)
    }

    /// Receive the next raw frame from the underlying socket.
    fn recv_bytes(&self) -> BoxFuture<'_, Result<Vec<u8>, Self::RecvError>> {
        let socket = self.socket.clone();
        async move {
            tracing::debug!("client waiting to receive bytes");
            Transport::<Sendable>::recv_bytes(&socket).await
        }
        .boxed()
    }
}

/// Bridging impl: delegates to [`Transport`] with encode/decode so that
/// [`Reconnect`] (which requires `Connection`) is satisfied. Downstream
/// code should prefer [`MessageTransport`] for new integrations.
///
/// [`MessageTransport`]: subduction_core::transport::message::MessageTransport
impl<R: Signer<Sendable> + Clone + Send + Sync> Connection<Sendable, SyncMessage>
    for TokioWebSocketClient<R>
{
    type SendError = SendError;
    type RecvError = RecvError;
    type DisconnectionError = DisconnectionError;

    fn disconnect(&self) -> BoxFuture<'_, Result<(), Self::DisconnectionError>> {
        Transport::<Sendable>::disconnect(self)
    }

    fn send(&self, message: &SyncMessage) -> BoxFuture<'_, Result<(), Self::SendError>> {
        let bytes = message.encode();
        let this = self.socket.clone();
        async move { Transport::<Sendable>::send_bytes(&this, &bytes).await }.boxed()
    }

    fn recv(&self) -> BoxFuture<'_, Result<SyncMessage, Self::RecvError>> {
        let socket = self.socket.clone();
        async move {
            loop {
                let bytes = Transport::<Sendable>::recv_bytes(&socket).await?;
                match SyncMessage::try_decode(&bytes) {
                    Ok(msg) => return Ok(msg),
                    Err(e) => {
                        tracing::warn!("failed to decode inbound bytes as SyncMessage: {e}");
                        // Skip non-decodable frames and keep reading
                    }
                }
            }
        }
        .boxed()
    }
}

impl<R> Reconnect<Sendable, SyncMessage> for TokioWebSocketClient<R>
where
    R: 'static + Signer<Sendable> + Clone + Send + Sync,
{
    type ReconnectionError = ClientConnectError;

    /// Re-dial the stored `address` with the stored `signer` and `audience`, run a fresh
    /// handshake, and replace `self` with the newly authenticated client.
    ///
    /// The new listener and sender tasks are spawned on the Tokio runtime. Their failures are
    /// logged at `info` level, not propagated. The previous socket is not explicitly
    /// disconnected here.
    ///
    /// # Errors
    ///
    /// Returns the error from [`TokioWebSocketClient::new`] if dialing or the handshake fails.
    /// In that case `self` is left untouched.
    fn reconnect(&mut self) -> BoxFuture<'_, Result<(), Self::ReconnectionError>> {
        async move {
            let address = self.address.clone();
            let signer = self.signer.clone();
            let (fresh, listener_task, sender_task) =
                TokioWebSocketClient::<R>::new(address, signer, self.audience).await?;

            // Drop the `Authenticated` wrapper; only the client itself is stored.
            *self = fresh.into_inner();

            tokio::spawn(async move {
                match listener_task.await {
                    Ok(_) => {}
                    Err(e) => {
                        tracing::info!("WebSocket client listener disconnected after reconnect: {e:?}");
                    }
                }
            });
            tokio::spawn(async move {
                match sender_task.await {
                    Ok(_) => {}
                    Err(e) => {
                        tracing::info!("WebSocket client sender disconnected after reconnect: {e:?}");
                    }
                }
            });

            Ok(())
        }
        .boxed()
    }

    /// Classify a reconnection failure as transient (`true`) or permanent (`false`).
    ///
    /// Network-level failures and connections that close mid-handshake are retried. Protocol
    /// violations, decode failures and explicit rejections are not. The match is exhaustive,
    /// so adding a new error variant forces a decision here.
    fn should_retry(&self, error: &Self::ReconnectionError) -> bool {
        match error {
            ClientConnectError::WebSocket(_)
            | ClientConnectError::Handshake(
                AuthenticateError::ConnectionClosed
                | AuthenticateError::Transport(
                    WebSocketHandshakeError::WebSocket(_)
                    | WebSocketHandshakeError::ConnectionClosed,
                ),
            ) => true,

            ClientConnectError::Handshake(
                AuthenticateError::Transport(WebSocketHandshakeError::UnexpectedMessageType(_))
                | AuthenticateError::Decode(_)
                | AuthenticateError::Handshake(_)
                | AuthenticateError::Rejected { .. }
                | AuthenticateError::UnexpectedMessage,
            ) => false,
        }
    }
}

impl<R> PartialEq for TokioWebSocketClient<R>
where
    R: Signer<Sendable> + Clone + Send + Sync,
{
    /// Clients are equal when they dialed the same address and their sockets report the same
    /// peer id. The signer and audience are not compared.
    fn eq(&self, other: &Self) -> bool {
        if self.address != other.address {
            return false;
        }
        self.socket.peer_id() == other.socket.peer_id()
    }
}