stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
//! Common handshaker implementations.

use async_trait::async_trait;
use std::time::Duration;
use tungstenite::Message;

use super::traits::{HandshakeReceiver, HandshakeSender, Handshaker};
use crate::context::ConnectionContext;
use crate::error::HandshakeError;

/// Send a single message handshaker
///
/// Sends a predefined message and optionally waits for a response.
#[derive(Debug, Clone)]
pub struct SendMessageHandshaker {
    /// Message to send (cloned on every handshake attempt)
    message: Message,
    /// Whether to wait for a response after sending
    wait_response: bool,
    /// Maximum time to wait for the response
    timeout: Duration,
}

impl SendMessageHandshaker {
    /// Response timeout applied until [`Self::with_timeout`] overrides it.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Build a handshaker that sends `message` once and does not wait for a reply.
    ///
    /// The response timeout starts at 10 seconds.
    #[must_use]
    pub const fn new(message: Message) -> Self {
        Self {
            timeout: Self::DEFAULT_TIMEOUT,
            wait_response: false,
            message,
        }
    }

    /// Build a handshaker whose message is a text frame carrying `text`.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        let payload: String = text.into();
        Self::new(Message::Text(payload.into()))
    }

    /// Build a handshaker whose message is a binary frame carrying `data`.
    #[must_use]
    pub fn binary(data: impl Into<Vec<u8>>) -> Self {
        let bytes: Vec<u8> = data.into();
        Self::new(Message::Binary(bytes.into()))
    }

    /// After sending, require one incoming message before the handshake succeeds.
    #[must_use]
    pub const fn with_response(mut self) -> Self {
        self.wait_response = true;
        self
    }

    /// Replace the response timeout (10 seconds unless changed).
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

#[async_trait]
impl Handshaker for SendMessageHandshaker {
    /// Send the configured message and, if [`SendMessageHandshaker::with_response`] was
    /// used, wait up to `self.timeout` for one incoming message.
    ///
    /// The content of the reply is not inspected. End of stream while waiting yields
    /// [`HandshakeError::Failed`], an expired wait yields [`HandshakeError::Timeout`], and
    /// errors from the sender or receiver are propagated.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        _context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        sender.send_msg(self.message.clone()).await?;

        if !self.wait_response {
            // Fire-and-forget: the handshake is done once the message is out.
            return Ok(());
        }

        let reply = tokio::time::timeout(self.timeout, receiver.recv_msg())
            .await
            .map_err(|_elapsed| HandshakeError::Timeout(self.timeout))??;

        match reply {
            None => Err(HandshakeError::Failed(
                "Connection closed during handshake".into(),
            )),
            Some(_) => Ok(()),
        }
    }

    fn name(&self) -> &'static str {
        "send_message"
    }

    /// The response timeout doubles as the overall handshake timeout.
    fn timeout(&self) -> Option<Duration> {
        Some(self.timeout)
    }
}

/// Authentication handshaker
///
/// Sends an authentication token and waits for confirmation.
///
/// `Debug` is implemented by hand so the token never appears in debug output or logs.
#[derive(Clone)]
pub struct AuthHandshaker {
    /// Authentication token or message
    token: String,
    /// Format for the auth message
    format: AuthFormat,
    /// Maximum time to wait for the server's reply
    timeout: Duration,
}

impl std::fmt::Debug for AuthHandshaker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The token is a credential; print a placeholder instead of its value.
        f.debug_struct("AuthHandshaker")
            .field("token", &"<redacted>")
            .field("format", &self.format)
            .field("timeout", &self.timeout)
            .finish()
    }
}

/// Authentication message format
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthFormat {
    /// The token is sent as-is, with no wrapping
    Plain,
    /// JSON object: `{"type":"auth","token":"<token>"}`
    Json,
    /// Custom template in which every `{}` is replaced by the token
    Custom(String),
}

impl AuthHandshaker {
    /// Create a new auth handshaker that sends `token` as plain text.
    ///
    /// The reply timeout defaults to 10 seconds.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
            format: AuthFormat::Plain,
            timeout: Duration::from_secs(10),
        }
    }

    /// Use JSON format
    #[must_use]
    pub fn json(mut self) -> Self {
        self.format = AuthFormat::Json;
        self
    }

    /// Use custom format (every `{}` in `format` is replaced by the raw token)
    #[must_use]
    pub fn custom_format(mut self, format: impl Into<String>) -> Self {
        self.format = AuthFormat::Custom(format.into());
        self
    }

    /// Set timeout
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Render the authentication message according to the configured [`AuthFormat`].
    ///
    /// In [`AuthFormat::Json`] the token is JSON-escaped so that quotes, backslashes or
    /// control characters in it cannot break (or inject fields into) the document.
    /// [`AuthFormat::Custom`] substitutes the token verbatim; the caller owns that template.
    fn format_message(&self) -> String {
        match &self.format {
            AuthFormat::Plain => self.token.clone(),
            AuthFormat::Json => format!(
                r#"{{"type":"auth","token":"{}"}}"#,
                Self::escape_json_string(&self.token)
            ),
            AuthFormat::Custom(fmt) => fmt.replace("{}", &self.token),
        }
    }

    /// Escape `raw` so it can be embedded between the quotes of a JSON string literal
    /// (RFC 8259 §7): `"`, `\` and all control characters below U+0020 are escaped.
    fn escape_json_string(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                '\u{08}' => escaped.push_str("\\b"),
                '\u{0c}' => escaped.push_str("\\f"),
                c if c < '\u{20}' => escaped.push_str(&format!("\\u{:04x}", u32::from(c))),
                c => escaped.push(c),
            }
        }
        escaped
    }
}

#[async_trait]
impl Handshaker for AuthHandshaker {
    /// Send the formatted auth message and wait up to `self.timeout` for the server's reply.
    ///
    /// Ping/Pong frames that arrive while waiting are skipped, so a keep-alive is never
    /// mistaken for the server's answer. The timeout covers the whole wait, including any
    /// skipped frames. Outcomes:
    /// - a text reply containing "error", "unauthorized" or "denied" (case-insensitive)
    ///   gives [`HandshakeError::AuthFailed`] with the reply text;
    /// - a Close frame or end of stream gives [`HandshakeError::Failed`];
    /// - no reply in time gives [`HandshakeError::Timeout`];
    /// - sender/receiver errors are propagated;
    /// - any other reply is treated as success.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        _context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        /// Case-insensitive substrings that mark a text reply as an auth rejection.
        const FAILURE_MARKERS: [&str; 3] = ["error", "unauthorized", "denied"];

        let message = self.format_message();
        tracing::debug!("Sending auth message");

        // Send auth message
        sender.send_msg(Message::Text(message.into())).await?;

        // Wait for the first non-keep-alive frame.
        let wait_for_reply = async {
            loop {
                match receiver.recv_msg().await? {
                    Some(Message::Ping(_) | Message::Pong(_)) => {}
                    other => return Ok::<_, HandshakeError>(other),
                }
            }
        };

        let reply = tokio::time::timeout(self.timeout, wait_for_reply)
            .await
            .map_err(|_| HandshakeError::Timeout(self.timeout))??;

        let msg = match reply {
            // A Close frame means the server is tearing the connection down, not confirming.
            Some(Message::Close(_)) | None => {
                return Err(HandshakeError::Failed(
                    "Connection closed during auth".into(),
                ))
            }
            Some(msg) => msg,
        };

        // Check for auth failure indicators in a text response
        if let Message::Text(t) = &msg {
            let text = t.to_string();
            let lower = text.to_lowercase();
            if FAILURE_MARKERS.iter().any(|marker| lower.contains(*marker)) {
                return Err(HandshakeError::AuthFailed(text));
            }
        }

        tracing::debug!("Auth successful");
        Ok(())
    }

    fn is_retryable(&self, error: &HandshakeError) -> bool {
        // Auth failures are typically not retryable
        !matches!(error, HandshakeError::AuthFailed(_))
    }

    fn name(&self) -> &'static str {
        "auth"
    }

    fn timeout(&self) -> Option<Duration> {
        Some(self.timeout)
    }
}

/// Subscribe handshaker
///
/// Sends subscription messages for multiple channels/topics.
#[derive(Debug, Clone)]
pub struct SubscribeHandshaker {
    /// Channels to subscribe to, in the order their messages are sent
    channels: Vec<String>,
    /// Format for subscription messages
    format: SubscribeFormat,
    /// Timeout for each individual confirmation wait
    timeout: Duration,
    /// Whether to wait for one confirmation message after each subscription
    wait_confirmation: bool,
}

/// Subscription message format
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SubscribeFormat {
    /// JSON object: `{"type":"subscribe","channel":"<channel>"}`
    Json,
    /// Custom template in which every `{}` is replaced by the channel name
    Custom(String),
}

impl SubscribeHandshaker {
    /// Create a new subscribe handshaker
    ///
    /// Defaults: JSON message format, a 5 second timeout per subscription, and no
    /// waiting for confirmations.
    #[must_use]
    pub fn new(channels: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            channels: channels.into_iter().map(Into::into).collect(),
            format: SubscribeFormat::Json,
            timeout: Duration::from_secs(5),
            wait_confirmation: false,
        }
    }

    /// Use custom format (every `{}` in `format` is replaced by the raw channel name)
    #[must_use]
    pub fn custom_format(mut self, format: impl Into<String>) -> Self {
        self.format = SubscribeFormat::Custom(format.into());
        self
    }

    /// Set timeout per subscription
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Wait for confirmation after each subscription
    #[must_use]
    pub const fn wait_confirmation(mut self) -> Self {
        self.wait_confirmation = true;
        self
    }

    /// Render the subscription message for `channel` according to the configured format.
    ///
    /// In [`SubscribeFormat::Json`] the channel name is JSON-escaped so that quotes,
    /// backslashes or control characters cannot produce invalid or injected JSON.
    fn format_message(&self, channel: &str) -> String {
        match &self.format {
            SubscribeFormat::Json => {
                let channel = Self::escape_json_string(channel);
                format!(r#"{{"type":"subscribe","channel":"{channel}"}}"#)
            }
            SubscribeFormat::Custom(fmt) => fmt.replace("{}", channel),
        }
    }

    /// Escape `raw` so it can be embedded between the quotes of a JSON string literal
    /// (RFC 8259 §7): `"`, `\` and all control characters below U+0020 are escaped.
    fn escape_json_string(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                '\u{08}' => escaped.push_str("\\b"),
                '\u{0c}' => escaped.push_str("\\f"),
                c if c < '\u{20}' => escaped.push_str(&format!("\\u{:04x}", u32::from(c))),
                c => escaped.push(c),
            }
        }
        escaped
    }
}

#[async_trait]
impl Handshaker for SubscribeHandshaker {
    /// Send one subscription message per channel, in order.
    ///
    /// With [`SubscribeHandshaker::wait_confirmation`] enabled, each send is followed by a
    /// wait of up to `self.timeout` for one incoming message (its content is not inspected)
    /// before moving on to the next channel.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        _context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        for channel in &self.channels {
            tracing::debug!(channel = %channel, "Subscribing to channel");
            let request = self.format_message(channel);
            sender.send_msg(Message::Text(request.into())).await?;

            if !self.wait_confirmation {
                continue;
            }

            let confirmation = tokio::time::timeout(self.timeout, receiver.recv_msg())
                .await
                .map_err(|_elapsed| HandshakeError::Timeout(self.timeout))??;

            if confirmation.is_none() {
                return Err(HandshakeError::Failed(format!(
                    "Connection closed while subscribing to {channel}"
                )));
            }
        }

        tracing::debug!(count = self.channels.len(), "Subscriptions complete");
        Ok(())
    }

    fn name(&self) -> &'static str {
        "subscribe"
    }

    /// Overall budget: the per-channel timeout times the number of channels.
    ///
    /// The channel count is clamped to `u32::MAX`, and the multiplication saturates
    /// instead of overflowing.
    fn timeout(&self) -> Option<Duration> {
        let channel_count = u32::try_from(self.channels.len()).unwrap_or(u32::MAX);
        let total = self.timeout.saturating_mul(channel_count);
        Some(total)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_format_plain() {
        let rendered = AuthHandshaker::new("my-token").format_message();
        assert_eq!(rendered, "my-token");
    }

    #[test]
    fn test_auth_format_json() {
        let expected = r#"{"type":"auth","token":"my-token"}"#;
        let rendered = AuthHandshaker::new("my-token").json().format_message();
        assert_eq!(rendered, expected);
    }

    #[test]
    fn test_auth_format_custom() {
        let rendered = AuthHandshaker::new("my-token")
            .custom_format("AUTH {}")
            .format_message();
        assert_eq!(rendered, "AUTH my-token");
    }

    #[test]
    fn test_subscribe_format() {
        let expected = r#"{"type":"subscribe","channel":"test"}"#;
        let sub = SubscribeHandshaker::new(["channel1"]);
        assert_eq!(sub.format_message("test"), expected);
    }
}