stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
//! Application-level handshake trait plus the type-erased sender/receiver abstractions it uses.

use async_trait::async_trait;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use std::time::Duration;
use tokio::time;
use tungstenite::Message;

use crate::context::ConnectionContext;
use crate::error::HandshakeError;

// No separate HandshakeContext type; use `crate::context::ConnectionContext` directly.

/// Type-erased sending half of a connection, as seen by a [`Handshaker`].
///
/// `#[async_trait]` keeps this trait object-safe, so a handshaker can send
/// through `&mut dyn HandshakeSender` without knowing the concrete sink type.
#[async_trait]
pub trait HandshakeSender: Send {
    /// Send `message` to the peer.
    ///
    /// # Errors
    ///
    /// Returns a [`HandshakeError`] if the message could not be sent.
    async fn send_msg(&mut self, message: Message) -> Result<(), HandshakeError>;
}

/// Type-erased receiving half of a connection, as seen by a [`Handshaker`].
///
/// `#[async_trait]` keeps this trait object-safe, so a handshaker can receive
/// through `&mut dyn HandshakeReceiver` without knowing the concrete stream type.
#[async_trait]
pub trait HandshakeReceiver: Send {
    /// Wait for the next message from the peer.
    ///
    /// Returns `Ok(Some(message))` when a message arrives. `Ok(None)` signals
    /// that no further messages will arrive. The blanket implementation for
    /// streams returns it once the underlying stream has ended.
    ///
    /// This method takes no timeout and may wait as long as the underlying
    /// source does. Bound the exchange with
    /// [`Handshaker::handshake_with_timeout`] (or an external timeout) when a
    /// deadline is required.
    ///
    /// # Errors
    ///
    /// Returns a [`HandshakeError`] if receiving fails.
    async fn recv_msg(&mut self) -> Result<Option<Message>, HandshakeError>;
}

// Blanket impl: any `Sink` of tungstenite messages (for example the write half
// of a WebSocket stream) can be handed to a handshaker directly.
#[async_trait]
impl<S> HandshakeSender for S
where
    S: Sink<Message, Error = tungstenite::Error> + Unpin + Send,
{
    async fn send_msg(&mut self, message: Message) -> Result<(), HandshakeError> {
        // `SinkExt::send` resolves only once the item has been flushed into the
        // sink; the tungstenite error is then lifted into the handshake error.
        let outcome = self.send(message).await;
        outcome.map_err(HandshakeError::from)
    }
}

// Blanket impl: any `Stream` of tungstenite message results (for example the
// read half of a WebSocket stream) can be handed to a handshaker directly.
#[async_trait]
impl<R> HandshakeReceiver for R
where
    R: Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send,
{
    async fn recv_msg(&mut self) -> Result<Option<Message>, HandshakeError> {
        // `transpose` maps end-of-stream (`None`) to `Ok(None)`, a message to
        // `Ok(Some(_))`, and a transport error to `Err(_)`, which is then converted.
        self.next().await.transpose().map_err(HandshakeError::from)
    }
}

/// Custom, application-level handshake performed over an already-open WebSocket.
///
/// Once the WebSocket connection itself is established, an implementor can
/// exchange whatever messages the application protocol requires before the
/// connection is considered ready (for example authentication or subscription
/// requests).
///
/// The `Send + Sync` supertraits let a handshaker be shared across tasks and
/// stored as a [`BoxHandshaker`].
#[async_trait]
pub trait Handshaker: Send + Sync {
    /// Run the handshake over the given connection halves.
    ///
    /// `sender` and `receiver` are type-erased views of the connection, and
    /// `context` describes the connection this handshake belongs to.
    ///
    /// # Errors
    ///
    /// Returns a [`HandshakeError`] when the handshake cannot be completed.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        context: &ConnectionContext,
    ) -> Result<(), HandshakeError>;

    /// Decide whether `error` should be treated as retryable.
    ///
    /// The default defers to [`HandshakeError::is_retryable`]; override it to
    /// apply handshaker-specific retry rules.
    fn is_retryable(&self, error: &HandshakeError) -> bool {
        error.is_retryable()
    }

    /// Identifier used in log output. Defaults to `"handshaker"`.
    fn name(&self) -> &'static str {
        "handshaker"
    }

    /// Upper bound applied by [`handshake_with_timeout`](Self::handshake_with_timeout).
    ///
    /// Defaults to 30 seconds; return `None` to disable the limit entirely.
    fn timeout(&self) -> Option<Duration> {
        Some(Duration::from_secs(30))
    }

    /// Run [`handshake`](Self::handshake), bounded by [`timeout`](Self::timeout).
    ///
    /// `timeout()` is consulted exactly once. When it yields `None`, the
    /// handshake is awaited with no deadline at all.
    ///
    /// # Errors
    ///
    /// Returns [`HandshakeError::Timeout`] carrying the configured duration if
    /// the deadline elapses before the handshake finishes. Otherwise this
    /// returns whatever `handshake()` produced.
    ///
    /// # Panics
    ///
    /// Relies on `tokio::time::timeout`, which panics when no Tokio timer is
    /// available (a runtime built without the time driver).
    async fn handshake_with_timeout(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        match self.timeout() {
            // No deadline configured: simply drive the handshake to completion.
            None => self.handshake(sender, receiver, context).await,
            Some(limit) => {
                let bounded = time::timeout(limit, self.handshake(sender, receiver, context));
                // An `Elapsed` error means the deadline won the race; report the limit.
                bounded
                    .await
                    .unwrap_or_else(|_elapsed| Err(HandshakeError::Timeout(limit)))
            }
        }
    }
}

/// Owned, type-erased [`Handshaker`].
///
/// Because `Handshaker` has `Send + Sync` as supertraits, values of this type
/// can be moved to and shared between threads. The trait object is `'static`.
/// That is already the default for boxed trait objects and is spelled out
/// here only for clarity.
pub type BoxHandshaker = Box<dyn Handshaker + 'static>;

/// Erase the concrete type of `h`, returning it as a [`BoxHandshaker`].
///
/// Use this when handshakers of different concrete types must be stored or
/// passed uniformly.
///
/// Marked `#[must_use]` because the function's only effect is its return
/// value: discarding it just drops the handshaker, which is a mistake.
#[must_use]
pub fn boxed<H: Handshaker + 'static>(h: H) -> BoxHandshaker {
    Box::new(h)
}