stream-tungstenite 0.6.1

A streaming WebSocket implementation built on the Tungstenite library
Documentation
//! Chained handshaker - executes multiple handshakers in sequence.

use async_trait::async_trait;
use std::time::Duration;

use super::traits::{BoxHandshaker, HandshakeReceiver, HandshakeSender, Handshaker};
use crate::context::ConnectionContext;
use crate::error::HandshakeError;

/// Chained handshaker - executes multiple handshakers in sequence.
///
/// Steps run in the order they were added, all sharing the same sender,
/// receiver and connection context. The first step that fails stops the
/// chain: later steps are skipped and that step's error is returned as-is.
/// A chain with no steps succeeds immediately.
///
/// A freshly constructed chain carries a 60-second overall timeout; see
/// [`ChainedHandshaker::with_timeout`] and [`ChainedHandshaker::without_timeout`].
///
/// # Example
///
/// ```rust,ignore
/// use stream_tungstenite::handshake::{ChainedHandshaker, AuthHandshaker, SubscribeHandshaker};
///
/// let handshaker = ChainedHandshaker::new()
///     .then(AuthHandshaker::new("my-token"))
///     .then(SubscribeHandshaker::new(vec!["channel1", "channel2"]));
/// ```
pub struct ChainedHandshaker {
    /// Steps to execute, in insertion order.
    handshakers: Vec<BoxHandshaker>,
    /// Overall timeout for the whole chain; `None` means no timeout.
    timeout: Option<Duration>,
}

impl std::fmt::Debug for ChainedHandshaker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `BoxHandshaker` may not implement `Debug`, so each step is shown by
        // its `name()` instead of its full contents.
        let steps: Vec<&'static str> = self.handshakers.iter().map(|h| h.name()).collect();
        f.debug_struct("ChainedHandshaker")
            .field("handshakers", &steps)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl ChainedHandshaker {
    /// Overall timeout installed by [`ChainedHandshaker::new`] (60 seconds).
    ///
    /// Override it with [`with_timeout`](Self::with_timeout) or remove it with
    /// [`without_timeout`](Self::without_timeout).
    pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);

    /// Create a chain with no steps and an overall timeout of
    /// [`Self::DEFAULT_TIMEOUT`].
    #[must_use]
    pub const fn new() -> Self {
        Self {
            handshakers: Vec::new(),
            timeout: Some(Self::DEFAULT_TIMEOUT),
        }
    }

    /// Append a handshaker to the end of the chain.
    ///
    /// Steps are executed in the order they are added.
    #[must_use]
    pub fn then<H: Handshaker + 'static>(mut self, handshaker: H) -> Self {
        self.handshakers.push(Box::new(handshaker));
        self
    }

    /// Append an already-boxed handshaker to the end of the chain.
    ///
    /// Useful when the concrete handshaker type is only known at runtime.
    #[must_use]
    pub fn then_boxed(mut self, handshaker: BoxHandshaker) -> Self {
        self.handshakers.push(handshaker);
        self
    }

    /// Set the overall timeout for the whole chain.
    ///
    /// The value is reported through this chain's `Handshaker::timeout`;
    /// the chain's own `handshake` method does not enforce it itself.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Remove the overall timeout, so `Handshaker::timeout` returns `None`.
    #[must_use]
    pub const fn without_timeout(mut self) -> Self {
        self.timeout = None;
        self
    }

    /// Number of handshakers in the chain.
    #[must_use]
    pub fn len(&self) -> usize {
        self.handshakers.len()
    }

    /// Whether the chain has no handshakers.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.handshakers.is_empty()
    }
}

impl Default for ChainedHandshaker {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Handshaker for ChainedHandshaker {
    /// Run every step of the chain, in insertion order, against the same
    /// sender, receiver and connection context.
    ///
    /// The first failing step ends the chain. Its error is logged at `warn`
    /// level and returned unchanged, and the remaining steps are not run.
    /// An empty chain returns `Ok(())` immediately.
    ///
    /// Only each step's `handshake` method is called here. A step's own
    /// `timeout()` and `is_retryable()` are not consulted by this loop.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        let total = self.handshakers.len();

        for (index, step) in self.handshakers.iter().enumerate() {
            // Log positions 1-based for human readability.
            let position = index + 1;

            tracing::debug!(
                step = position,
                total = total,
                handshaker = step.name(),
                "Executing handshake step"
            );

            if let Err(err) = step.handshake(sender, receiver, context).await {
                tracing::warn!(
                    step = position,
                    handshaker = step.name(),
                    error = ?err,
                    "Handshake step failed"
                );
                return Err(err);
            }
        }

        Ok(())
    }

    /// Decide whether a chain failure may be retried.
    ///
    /// This defers to the error's own `is_retryable()`. It does not ask the
    /// individual steps.
    fn is_retryable(&self, error: &HandshakeError) -> bool {
        error.is_retryable()
    }

    /// Fixed identifier for this handshaker kind.
    fn name(&self) -> &'static str {
        "chained"
    }

    /// The overall timeout configured on this chain, if any.
    fn timeout(&self) -> Option<Duration> {
        self.timeout
    }
}

/// Build a `ChainedHandshaker` from a comma-separated list of handshakers.
///
/// Each expression is appended with `ChainedHandshaker::then`, so steps run
/// in the order they are written. At least one handshaker is required, and a
/// trailing comma is accepted.
///
/// # Example
///
/// ```rust,ignore
/// use stream_tungstenite::handshake::NoOpHandshaker;
///
/// let handshaker = stream_tungstenite::chain_handshakers![
///     NoOpHandshaker,
///     NoOpHandshaker,
/// ];
/// ```
#[macro_export]
macro_rules! chain_handshakers {
    ( $( $step:expr ),+ $(,)? ) => {
        $crate::handshake::ChainedHandshaker::new() $( .then($step) )+
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::handshake::NoOpHandshaker;

    #[test]
    fn test_chained_handshaker_creation() {
        let chained = ChainedHandshaker::new()
            .then(NoOpHandshaker)
            .then(NoOpHandshaker);

        assert_eq!(chained.len(), 2);
        assert!(!chained.is_empty());
    }

    #[test]
    fn test_empty_chain() {
        let chained = ChainedHandshaker::new();
        assert!(chained.is_empty());
        assert_eq!(chained.len(), 0);
    }

    #[test]
    fn test_default_timeout_is_sixty_seconds() {
        let chained = ChainedHandshaker::new();
        assert_eq!(chained.timeout(), Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_timeout_builders() {
        let custom = ChainedHandshaker::new().with_timeout(Duration::from_secs(5));
        assert_eq!(custom.timeout(), Some(Duration::from_secs(5)));

        let disabled = ChainedHandshaker::new().without_timeout();
        assert_eq!(disabled.timeout(), None);
    }

    #[test]
    fn test_default_matches_new() {
        let from_default = ChainedHandshaker::default();
        assert!(from_default.is_empty());
        assert_eq!(from_default.timeout(), ChainedHandshaker::new().timeout());
    }

    #[test]
    fn test_then_boxed_appends_step() {
        let boxed: BoxHandshaker = Box::new(NoOpHandshaker);
        let chained = ChainedHandshaker::new().then_boxed(boxed);
        assert_eq!(chained.len(), 1);
    }

    #[test]
    fn test_macro_builds_chain() {
        // Trailing comma is deliberate: the macro accepts it.
        let chained = crate::chain_handshakers!(NoOpHandshaker, NoOpHandshaker,);
        assert_eq!(chained.len(), 2);
        assert_eq!(chained.name(), "chained");
    }
}