stream-tungstenite 0.6.1

A streaming implementation of the Tungstenite WebSocket protocol
Documentation
//! Message dispatcher: queues outgoing messages for the WebSocket sink and broadcasts incoming messages to subscribers.

use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use std::future::Future;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::task::JoinHandle;
use tungstenite::Message;

/// Shared message type for zero-copy broadcasting
pub type SharedMessage = Arc<Message>;

use crate::connection::WsStream;
use crate::error::{ExtensionError, ReceiveError, SendError};

/// Tuning knobs for [`MessageDispatcher`].
///
/// Start from [`DispatcherConfig::new`] (or `Default`) and adjust with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct DispatcherConfig {
    /// How long a receive loop waits for the next frame before failing with a timeout error.
    pub receive_timeout: Duration,
    /// Capacity of the broadcast channel that fans incoming messages out to subscribers.
    pub broadcast_capacity: usize,
    /// Capacity of the bounded internal queue that feeds the background send task.
    pub send_buffer_capacity: usize,
    /// What the processor-driven receive loop does when the processor returns an error.
    pub processor_error_policy: ProcessorErrorPolicy,
}

/// Policy for handling processor (extension) errors on the receive path.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare or key on policies
/// (e.g. `config.processor_error_policy == ProcessorErrorPolicy::Disconnect`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProcessorErrorPolicy {
    /// Ignore the error (log and continue)
    Ignore,
    /// Treat as fatal and disconnect (bubble up an error)
    Disconnect,
}

impl Default for DispatcherConfig {
    /// Defaults: 30 s receive timeout, broadcast capacity 1024, send buffer 256,
    /// and [`ProcessorErrorPolicy::Ignore`].
    fn default() -> Self {
        let receive_timeout = Duration::from_secs(30);
        let broadcast_capacity = 1024;
        let send_buffer_capacity = 256;
        Self {
            receive_timeout,
            broadcast_capacity,
            send_buffer_capacity,
            processor_error_policy: ProcessorErrorPolicy::Ignore,
        }
    }
}

impl DispatcherConfig {
    /// Create a config populated with the `Default` values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Return a copy of this config with `receive_timeout` replaced.
    #[must_use]
    pub const fn with_receive_timeout(self, timeout: Duration) -> Self {
        Self {
            receive_timeout: timeout,
            ..self
        }
    }

    /// Return a copy of this config with `broadcast_capacity` replaced.
    ///
    /// Keep this non-zero: tokio's broadcast channel panics when created with
    /// a capacity of 0.
    #[must_use]
    pub const fn with_broadcast_capacity(self, capacity: usize) -> Self {
        Self {
            broadcast_capacity: capacity,
            ..self
        }
    }

    /// Return a copy of this config with `send_buffer_capacity` replaced.
    ///
    /// Keep this non-zero: tokio's bounded mpsc channel panics when created
    /// with a capacity of 0.
    #[must_use]
    pub const fn with_send_buffer_capacity(self, capacity: usize) -> Self {
        Self {
            send_buffer_capacity: capacity,
            ..self
        }
    }

    /// Return a copy of this config with `processor_error_policy` replaced.
    #[must_use]
    pub const fn with_processor_error_policy(self, policy: ProcessorErrorPolicy) -> Self {
        Self {
            processor_error_policy: policy,
            ..self
        }
    }
}

/// Internal sender state managed by the dispatcher.
///
/// No value of type `S` is stored here; the type parameter only ties this
/// state to the dispatcher's stream type.
struct SenderState<S: WsStream> {
    /// Background task that owns the `SplitSink` and drains the queue behind `send_tx`
    send_task: Option<JoinHandle<()>>,
    /// Producer side of the internal queue for outgoing messages
    send_tx: Option<mpsc::Sender<Message>>,
    /// Type marker. `fn() -> S` (rather than plain `S`) keeps `SenderState`, and
    /// therefore the dispatcher holding it, `Send`/`Sync` independently of `S`,
    /// which is sound because no `S` is ever owned here.
    _marker: PhantomData<fn() -> S>,
}

/// Routes outgoing messages to the attached WebSocket sink and fans incoming
/// messages out to every subscriber.
///
/// `S` is the underlying stream type and defaults to
/// `crate::connection::DefaultWsStream`.
pub struct MessageDispatcher<S: WsStream = crate::connection::DefaultWsStream> {
    /// Settings supplied at construction time
    config: DispatcherConfig,
    /// Queue handle plus background send task; replaced on attach, cleared on
    /// detach, and guarded by an async `RwLock`
    sender_state: Arc<RwLock<SenderState<S>>>,
    /// Lock-free connection flag consulted before touching `sender_state`
    is_connected: Arc<AtomicBool>,
    /// Broadcast sender; each received message is wrapped once in an `Arc` so
    /// all subscribers share the same allocation
    message_tx: broadcast::Sender<SharedMessage>,
}

// Note: `future_not_send` is intentionally allowed here because the dispatcher
// is designed to be used within single-threaded async contexts where the stream
// type `S` may not implement `Sync`. The internal `RwLock` provides safe concurrent
// access within the same runtime.
#[allow(clippy::future_not_send)]
impl<S: WsStream> MessageDispatcher<S> {
    /// Create a new message dispatcher
    #[must_use]
    pub fn new(config: DispatcherConfig) -> Self {
        let (message_tx, _) = broadcast::channel(config.broadcast_capacity);

        Self {
            config,
            sender_state: Arc::new(RwLock::new(SenderState::<S> {
                send_task: None,
                send_tx: None,
                _marker: PhantomData,
            })),
            is_connected: Arc::new(AtomicBool::new(false)),
            message_tx,
        }
    }

    /// Attach a sender (called when connection is established)
    pub async fn attach(&self, sender: SplitSink<S, Message>) {
        // Create internal queue
        let (tx, mut rx) = mpsc::channel::<Message>(self.config.send_buffer_capacity);

        // Spawn background send task that owns the sink
        let connected = self.is_connected.clone();
        let send_task = tokio::spawn(async move {
            let mut sink = sender;
            while let Some(msg) = rx.recv().await {
                // On send error, mark disconnected and stop the task
                if let Err(e) = sink.send(msg).await {
                    tracing::debug!(error = ?e, "Dispatcher send task encountered error");
                    connected.store(false, Ordering::Release);
                    break;
                }
            }
        });

        // Publish state
        {
            let mut state = self.sender_state.write().await;
            // Clean up any previous task/channel if present
            if let Some(handle) = state.send_task.take() {
                handle.abort();
            }
            state.send_tx = Some(tx);
            state.send_task = Some(send_task);
        }
        // Set connected after state is visible
        self.is_connected.store(true, Ordering::Release);
        tracing::debug!("Message dispatcher attached");
    }

    /// Detach the sender (called when connection is lost)
    pub async fn detach(&self) {
        self.is_connected.store(false, Ordering::Release);
        {
            let mut state = self.sender_state.write().await;
            // Drop the channel to stop producers
            state.send_tx = None;
            // Abort background task
            if let Some(handle) = state.send_task.take() {
                handle.abort();
            }
        }
        tracing::debug!("Message dispatcher detached");
    }

    /// Check if connected (fast path, no lock)
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.is_connected.load(Ordering::Acquire)
    }

    /// Send a message
    ///
    /// # Errors
    ///
    /// - Returns [`SendError::NotConnected`] if not currently connected.
    /// - Returns [`SendError::ChannelClosed`] if the internal send queue is closed.
    pub async fn send(&self, msg: Message) -> Result<(), SendError> {
        // Fast path
        if !self.is_connected() {
            return Err(SendError::NotConnected);
        }
        // Clone tx without holding the lock across await
        let tx = {
            let state = self.sender_state.read().await;
            state.send_tx.clone()
        };
        match tx {
            Some(tx) => tx.send(msg).await.map_err(|_| SendError::ChannelClosed),
            None => Err(SendError::NotConnected),
        }
    }

    /// Subscribe to messages
    ///
    /// Returns a receiver for shared messages. Messages are wrapped in `Arc<Message>`
    /// for zero-copy broadcasting. To get owned `Message`:
    /// - Read-only access: `msg.as_ref()`
    /// - Need ownership: `Arc::try_unwrap(msg).unwrap_or_else(|arc| (*arc).clone())`
    #[must_use]
    pub fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
        self.message_tx.subscribe()
    }

    /// Get the number of message subscribers
    #[must_use]
    pub fn subscriber_count(&self) -> usize {
        self.message_tx.receiver_count()
    }

    /// Run the receive loop
    ///
    /// This consumes messages from the receiver and broadcasts them to subscribers.
    /// Returns when the connection is closed or an error occurs.
    ///
    /// # Errors
    ///
    /// - Returns [`ReceiveError::WebSocket`] if a WebSocket error occurs.
    /// - Returns [`ReceiveError::StreamClosed`] if the stream is closed.
    /// - Returns [`ReceiveError::Timeout`] if no message is received within the configured timeout.
    pub async fn receive_loop(&self, mut receiver: SplitStream<S>) -> Result<(), ReceiveError> {
        let timeout = self.config.receive_timeout;

        loop {
            let result = tokio::time::timeout(timeout, receiver.next()).await;

            match result {
                Ok(Some(Ok(msg))) => {
                    // Broadcast message to all subscribers (zero-copy with Arc)
                    // Ignore send errors (no subscribers)
                    let _ = self.message_tx.send(Arc::new(msg));
                }
                Ok(Some(Err(e))) => {
                    tracing::debug!(error = ?e, "WebSocket receive error");
                    return Err(ReceiveError::WebSocket(e.to_string()));
                }
                Ok(None) => {
                    tracing::debug!("WebSocket stream closed");
                    return Err(ReceiveError::StreamClosed);
                }
                Err(_) => {
                    tracing::debug!(timeout = ?timeout, "Receive timeout");
                    return Err(ReceiveError::Timeout(timeout));
                }
            }
        }
    }

    /// Receive loop with activity callback
    ///
    /// Calls the provided callback on each received message for activity tracking.
    ///
    /// # Errors
    ///
    /// - Returns [`ReceiveError::WebSocket`] if a WebSocket error occurs.
    /// - Returns [`ReceiveError::StreamClosed`] if the stream is closed.
    /// - Returns [`ReceiveError::Timeout`] if no message is received within the configured timeout.
    pub async fn receive_loop_with_activity<F>(
        &self,
        mut receiver: SplitStream<S>,
        on_activity: F,
    ) -> Result<(), ReceiveError>
    where
        F: Fn() + Send + Sync,
    {
        let timeout = self.config.receive_timeout;

        loop {
            let result = tokio::time::timeout(timeout, receiver.next()).await;

            match result {
                Ok(Some(Ok(msg))) => {
                    // Notify activity
                    on_activity();

                    // Broadcast message (zero-copy with Arc)
                    let _ = self.message_tx.send(Arc::new(msg));
                }
                Ok(Some(Err(e))) => {
                    return Err(ReceiveError::WebSocket(e.to_string()));
                }
                Ok(None) => {
                    return Err(ReceiveError::StreamClosed);
                }
                Err(_) => {
                    return Err(ReceiveError::Timeout(timeout));
                }
            }
        }
    }

    /// Receive loop with async activity callback and async processor
    ///
    /// The processor can transform or filter messages. Returning Ok(Some(msg)) broadcasts it,
    /// Ok(None) drops it, Err(_) logs and continues.
    ///
    /// # Errors
    ///
    /// - Returns [`ReceiveError::WebSocket`] if a WebSocket error occurs.
    /// - Returns [`ReceiveError::StreamClosed`] if the stream is closed.
    /// - Returns [`ReceiveError::Timeout`] if no message is received within the configured timeout.
    pub async fn receive_loop_with_processor<FAct, FActFut, FProc, FProcFut>(
        &self,
        mut receiver: SplitStream<S>,
        on_activity: FAct,
        processor: FProc,
    ) -> Result<(), ReceiveError>
    where
        FAct: Fn() -> FActFut + Send + Sync,
        FActFut: Future<Output = ()> + Send,
        FProc: Fn(Message) -> FProcFut + Send + Sync,
        FProcFut: Future<Output = Result<Option<Message>, ExtensionError>> + Send,
    {
        let timeout = self.config.receive_timeout;

        loop {
            let result = tokio::time::timeout(timeout, receiver.next()).await;

            match result {
                Ok(Some(Ok(msg))) => {
                    // Notify activity
                    on_activity().await;

                    // Process via processor
                    match processor(msg).await {
                        Ok(Some(broadcast_msg)) => {
                            let _ = self.message_tx.send(Arc::new(broadcast_msg));
                        }
                        Ok(None) => {
                            // filtered
                        }
                        Err(e) => match self.config.processor_error_policy {
                            ProcessorErrorPolicy::Ignore => {
                                tracing::warn!(error = ?e, "Message processor failed");
                            }
                            ProcessorErrorPolicy::Disconnect => {
                                return Err(ReceiveError::WebSocket(e.to_string()));
                            }
                        },
                    }
                }
                Ok(Some(Err(e))) => {
                    return Err(ReceiveError::WebSocket(e.to_string()));
                }
                Ok(None) => {
                    return Err(ReceiveError::StreamClosed);
                }
                Err(_) => {
                    return Err(ReceiveError::Timeout(timeout));
                }
            }
        }
    }
}

impl<S: WsStream> Default for MessageDispatcher<S> {
    /// Build a dispatcher from the default [`DispatcherConfig`].
    fn default() -> Self {
        let config = DispatcherConfig::new();
        Self::new(config)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatcher_config() {
        let config = DispatcherConfig::new()
            .with_receive_timeout(Duration::from_secs(60))
            .with_broadcast_capacity(2048);

        assert_eq!(config.receive_timeout, Duration::from_secs(60));
        assert_eq!(config.broadcast_capacity, 2048);
    }

    /// The documented defaults must not drift silently.
    #[test]
    fn test_dispatcher_config_defaults() {
        let config = DispatcherConfig::default();

        assert_eq!(config.receive_timeout, Duration::from_secs(30));
        assert_eq!(config.broadcast_capacity, 1024);
        assert_eq!(config.send_buffer_capacity, 256);
        assert!(matches!(
            config.processor_error_policy,
            ProcessorErrorPolicy::Ignore
        ));
    }

    /// Covers the builder methods not exercised by `test_dispatcher_config`,
    /// and checks that builders leave the other fields untouched.
    #[test]
    fn test_dispatcher_config_remaining_builders() {
        let config = DispatcherConfig::new()
            .with_send_buffer_capacity(8)
            .with_processor_error_policy(ProcessorErrorPolicy::Disconnect);

        assert_eq!(config.send_buffer_capacity, 8);
        assert!(matches!(
            config.processor_error_policy,
            ProcessorErrorPolicy::Disconnect
        ));
        assert_eq!(config.receive_timeout, Duration::from_secs(30));
        assert_eq!(config.broadcast_capacity, 1024);
    }

    #[tokio::test]
    async fn test_dispatcher_not_connected() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();
        assert!(!dispatcher.is_connected());

        // Should fail when not connected
        let result = dispatcher.send(Message::Text("test".into())).await;
        assert!(matches!(result, Err(SendError::NotConnected)));
    }

    /// `new` drops the channel's initial receiver, so only explicit
    /// subscriptions are counted.
    #[test]
    fn test_subscriber_count_tracks_receivers() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();
        assert_eq!(dispatcher.subscriber_count(), 0);

        let rx = dispatcher.subscribe();
        assert_eq!(dispatcher.subscriber_count(), 1);

        drop(rx);
        assert_eq!(dispatcher.subscriber_count(), 0);
    }
}