stream_tungstenite/message/
dispatcher.rs

1//! Message dispatcher - handles message routing and distribution.
2
3use futures_util::stream::{SplitSink, SplitStream};
4use futures_util::{SinkExt, StreamExt};
5use std::future::Future;
6use std::marker::PhantomData;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{broadcast, mpsc, RwLock};
11use tokio::task::JoinHandle;
12use tungstenite::Message;
13
14/// Shared message type for zero-copy broadcasting
15pub type SharedMessage = Arc<Message>;
16
17use crate::connection::WsStream;
18use crate::error::{ExtensionError, ReceiveError, SendError};
19
20/// Message dispatcher configuration
21#[derive(Debug, Clone)]
22pub struct DispatcherConfig {
23    /// Receive timeout
24    pub receive_timeout: Duration,
25    /// Message broadcast channel capacity
26    pub broadcast_capacity: usize,
27    /// Outgoing send buffer capacity used by the dispatcher internal queue
28    pub send_buffer_capacity: usize,
29    /// Policy for handling processor errors
30    pub processor_error_policy: ProcessorErrorPolicy,
31}
32
/// Policy for handling processor (extension) errors on the receive path.
///
/// Consulted by the dispatcher's processor-based receive loop whenever the message processor
/// returns an error. The type is a plain `Copy` value and can be compared and hashed, so
/// callers can test or key on the configured policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProcessorErrorPolicy {
    /// Ignore the error: log it at `warn` level and keep receiving.
    Ignore,
    /// Treat the error as fatal: the receive loop returns (bubbles up) an error so the
    /// caller can disconnect.
    Disconnect,
}
41
42impl Default for DispatcherConfig {
43    fn default() -> Self {
44        Self {
45            receive_timeout: Duration::from_secs(30),
46            broadcast_capacity: 1024,
47            send_buffer_capacity: 256,
48            processor_error_policy: ProcessorErrorPolicy::Ignore,
49        }
50    }
51}
52
53impl DispatcherConfig {
54    /// Create a new dispatcher config
55    #[must_use]
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// Set receive timeout
61    #[must_use]
62    pub const fn with_receive_timeout(mut self, timeout: Duration) -> Self {
63        self.receive_timeout = timeout;
64        self
65    }
66
67    /// Set broadcast channel capacity
68    #[must_use]
69    pub const fn with_broadcast_capacity(mut self, capacity: usize) -> Self {
70        self.broadcast_capacity = capacity;
71        self
72    }
73
74    /// Set send buffer capacity for dispatcher's internal queue
75    #[must_use]
76    pub const fn with_send_buffer_capacity(mut self, capacity: usize) -> Self {
77        self.send_buffer_capacity = capacity;
78        self
79    }
80
81    /// Set processor error handling policy
82    #[must_use]
83    pub const fn with_processor_error_policy(mut self, policy: ProcessorErrorPolicy) -> Self {
84        self.processor_error_policy = policy;
85        self
86    }
87}
88
89/// Internal sender state managed by the dispatcher
90struct SenderState<S: WsStream> {
91    /// Background task that owns the `SplitSink` and pulls from `send_rx`
92    send_task: Option<JoinHandle<()>>,
93    /// Internal queue for outgoing messages
94    send_tx: Option<mpsc::Sender<Message>>,
95    /// Marker for the stream type
96    _marker: PhantomData<S>,
97}
98
99/// Message dispatcher - handles sending and receiving messages
100pub struct MessageDispatcher<S: WsStream = crate::connection::DefaultWsStream> {
101    /// Configuration
102    config: DispatcherConfig,
103    /// Sender state (protected by `RwLock` for concurrent attach/detach)
104    sender_state: Arc<RwLock<SenderState<S>>>,
105    /// Fast path connection check
106    is_connected: Arc<AtomicBool>,
107    /// Message broadcaster (Arc-wrapped for zero-copy)
108    message_tx: broadcast::Sender<SharedMessage>,
109}
110
111// Note: `future_not_send` is intentionally allowed here because the dispatcher
112// is designed to be used within single-threaded async contexts where the stream
113// type `S` may not implement `Sync`. The internal `RwLock` provides safe concurrent
114// access within the same runtime.
115#[allow(clippy::future_not_send)]
116impl<S: WsStream> MessageDispatcher<S> {
117    /// Create a new message dispatcher
118    #[must_use]
119    pub fn new(config: DispatcherConfig) -> Self {
120        let (message_tx, _) = broadcast::channel(config.broadcast_capacity);
121
122        Self {
123            config,
124            sender_state: Arc::new(RwLock::new(SenderState::<S> {
125                send_task: None,
126                send_tx: None,
127                _marker: PhantomData,
128            })),
129            is_connected: Arc::new(AtomicBool::new(false)),
130            message_tx,
131        }
132    }
133
134    /// Attach a sender (called when connection is established)
135    pub async fn attach(&self, sender: SplitSink<S, Message>) {
136        // Create internal queue
137        let (tx, mut rx) = mpsc::channel::<Message>(self.config.send_buffer_capacity);
138
139        // Spawn background send task that owns the sink
140        let connected = self.is_connected.clone();
141        let send_task = tokio::spawn(async move {
142            let mut sink = sender;
143            while let Some(msg) = rx.recv().await {
144                // On send error, mark disconnected and stop the task
145                if let Err(e) = sink.send(msg).await {
146                    tracing::debug!(error = ?e, "Dispatcher send task encountered error");
147                    connected.store(false, Ordering::Release);
148                    break;
149                }
150            }
151        });
152
153        // Publish state
154        {
155            let mut state = self.sender_state.write().await;
156            // Clean up any previous task/channel if present
157            if let Some(handle) = state.send_task.take() {
158                handle.abort();
159            }
160            state.send_tx = Some(tx);
161            state.send_task = Some(send_task);
162        }
163        // Set connected after state is visible
164        self.is_connected.store(true, Ordering::Release);
165        tracing::debug!("Message dispatcher attached");
166    }
167
168    /// Detach the sender (called when connection is lost)
169    pub async fn detach(&self) {
170        self.is_connected.store(false, Ordering::Release);
171        {
172            let mut state = self.sender_state.write().await;
173            // Drop the channel to stop producers
174            state.send_tx = None;
175            // Abort background task
176            if let Some(handle) = state.send_task.take() {
177                handle.abort();
178            }
179        }
180        tracing::debug!("Message dispatcher detached");
181    }
182
183    /// Check if connected (fast path, no lock)
184    #[must_use]
185    pub fn is_connected(&self) -> bool {
186        self.is_connected.load(Ordering::Acquire)
187    }
188
189    /// Send a message
190    ///
191    /// # Errors
192    ///
193    /// - Returns [`SendError::NotConnected`] if not currently connected.
194    /// - Returns [`SendError::ChannelClosed`] if the internal send queue is closed.
195    pub async fn send(&self, msg: Message) -> Result<(), SendError> {
196        // Fast path
197        if !self.is_connected() {
198            return Err(SendError::NotConnected);
199        }
200        // Clone tx without holding the lock across await
201        let tx = {
202            let state = self.sender_state.read().await;
203            state.send_tx.clone()
204        };
205        match tx {
206            Some(tx) => tx.send(msg).await.map_err(|_| SendError::ChannelClosed),
207            None => Err(SendError::NotConnected),
208        }
209    }
210
211    /// Subscribe to messages
212    ///
213    /// Returns a receiver for shared messages. Messages are wrapped in `Arc<Message>`
214    /// for zero-copy broadcasting. To get owned `Message`:
215    /// - Read-only access: `msg.as_ref()`
216    /// - Need ownership: `Arc::try_unwrap(msg).unwrap_or_else(|arc| (*arc).clone())`
217    #[must_use]
218    pub fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
219        self.message_tx.subscribe()
220    }
221
222    /// Get the number of message subscribers
223    #[must_use]
224    pub fn subscriber_count(&self) -> usize {
225        self.message_tx.receiver_count()
226    }
227
228    /// Run the receive loop
229    ///
230    /// This consumes messages from the receiver and broadcasts them to subscribers.
231    /// Returns when the connection is closed or an error occurs.
232    ///
233    /// # Errors
234    ///
235    /// - Returns [`ReceiveError::WebSocket`] if a WebSocket error occurs.
236    /// - Returns [`ReceiveError::StreamClosed`] if the stream is closed.
237    /// - Returns [`ReceiveError::Timeout`] if no message is received within the configured timeout.
238    pub async fn receive_loop(&self, mut receiver: SplitStream<S>) -> Result<(), ReceiveError> {
239        let timeout = self.config.receive_timeout;
240
241        loop {
242            let result = tokio::time::timeout(timeout, receiver.next()).await;
243
244            match result {
245                Ok(Some(Ok(msg))) => {
246                    // Broadcast message to all subscribers (zero-copy with Arc)
247                    // Ignore send errors (no subscribers)
248                    let _ = self.message_tx.send(Arc::new(msg));
249                }
250                Ok(Some(Err(e))) => {
251                    tracing::debug!(error = ?e, "WebSocket receive error");
252                    return Err(ReceiveError::WebSocket(e.to_string()));
253                }
254                Ok(None) => {
255                    tracing::debug!("WebSocket stream closed");
256                    return Err(ReceiveError::StreamClosed);
257                }
258                Err(_) => {
259                    tracing::debug!(timeout = ?timeout, "Receive timeout");
260                    return Err(ReceiveError::Timeout(timeout));
261                }
262            }
263        }
264    }
265
266    /// Receive loop with activity callback
267    ///
268    /// Calls the provided callback on each received message for activity tracking.
269    ///
270    /// # Errors
271    ///
272    /// - Returns [`ReceiveError::WebSocket`] if a WebSocket error occurs.
273    /// - Returns [`ReceiveError::StreamClosed`] if the stream is closed.
274    /// - Returns [`ReceiveError::Timeout`] if no message is received within the configured timeout.
275    pub async fn receive_loop_with_activity<F>(
276        &self,
277        mut receiver: SplitStream<S>,
278        on_activity: F,
279    ) -> Result<(), ReceiveError>
280    where
281        F: Fn() + Send + Sync,
282    {
283        let timeout = self.config.receive_timeout;
284
285        loop {
286            let result = tokio::time::timeout(timeout, receiver.next()).await;
287
288            match result {
289                Ok(Some(Ok(msg))) => {
290                    // Notify activity
291                    on_activity();
292
293                    // Broadcast message (zero-copy with Arc)
294                    let _ = self.message_tx.send(Arc::new(msg));
295                }
296                Ok(Some(Err(e))) => {
297                    return Err(ReceiveError::WebSocket(e.to_string()));
298                }
299                Ok(None) => {
300                    return Err(ReceiveError::StreamClosed);
301                }
302                Err(_) => {
303                    return Err(ReceiveError::Timeout(timeout));
304                }
305            }
306        }
307    }
308
309    /// Receive loop with async activity callback and async processor
310    ///
311    /// The processor can transform or filter messages. Returning Ok(Some(msg)) broadcasts it,
312    /// Ok(None) drops it, Err(_) logs and continues.
313    ///
314    /// # Errors
315    ///
316    /// - Returns [`ReceiveError::WebSocket`] if a WebSocket error occurs.
317    /// - Returns [`ReceiveError::StreamClosed`] if the stream is closed.
318    /// - Returns [`ReceiveError::Timeout`] if no message is received within the configured timeout.
319    pub async fn receive_loop_with_processor<FAct, FActFut, FProc, FProcFut>(
320        &self,
321        mut receiver: SplitStream<S>,
322        on_activity: FAct,
323        processor: FProc,
324    ) -> Result<(), ReceiveError>
325    where
326        FAct: Fn() -> FActFut + Send + Sync,
327        FActFut: Future<Output = ()> + Send,
328        FProc: Fn(Message) -> FProcFut + Send + Sync,
329        FProcFut: Future<Output = Result<Option<Message>, ExtensionError>> + Send,
330    {
331        let timeout = self.config.receive_timeout;
332
333        loop {
334            let result = tokio::time::timeout(timeout, receiver.next()).await;
335
336            match result {
337                Ok(Some(Ok(msg))) => {
338                    // Notify activity
339                    on_activity().await;
340
341                    // Process via processor
342                    match processor(msg).await {
343                        Ok(Some(broadcast_msg)) => {
344                            let _ = self.message_tx.send(Arc::new(broadcast_msg));
345                        }
346                        Ok(None) => {
347                            // filtered
348                        }
349                        Err(e) => match self.config.processor_error_policy {
350                            ProcessorErrorPolicy::Ignore => {
351                                tracing::warn!(error = ?e, "Message processor failed");
352                            }
353                            ProcessorErrorPolicy::Disconnect => {
354                                return Err(ReceiveError::WebSocket(e.to_string()));
355                            }
356                        },
357                    }
358                }
359                Ok(Some(Err(e))) => {
360                    return Err(ReceiveError::WebSocket(e.to_string()));
361                }
362                Ok(None) => {
363                    return Err(ReceiveError::StreamClosed);
364                }
365                Err(_) => {
366                    return Err(ReceiveError::Timeout(timeout));
367                }
368            }
369        }
370    }
371}
372
373impl<S: WsStream> Default for MessageDispatcher<S> {
374    fn default() -> Self {
375        Self::new(DispatcherConfig::default())
376    }
377}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method must overwrite exactly its own field.
    #[test]
    fn test_dispatcher_config() {
        let config = DispatcherConfig::new()
            .with_receive_timeout(Duration::from_secs(60))
            .with_broadcast_capacity(2048)
            .with_send_buffer_capacity(64)
            .with_processor_error_policy(ProcessorErrorPolicy::Disconnect);

        assert_eq!(config.receive_timeout, Duration::from_secs(60));
        assert_eq!(config.broadcast_capacity, 2048);
        assert_eq!(config.send_buffer_capacity, 64);
        assert!(matches!(
            config.processor_error_policy,
            ProcessorErrorPolicy::Disconnect
        ));
    }

    /// The documented defaults are what `DispatcherConfig::default()` produces.
    #[test]
    fn test_dispatcher_config_defaults() {
        let config = DispatcherConfig::default();

        assert_eq!(config.receive_timeout, Duration::from_secs(30));
        assert_eq!(config.broadcast_capacity, 1024);
        assert_eq!(config.send_buffer_capacity, 256);
        assert!(matches!(
            config.processor_error_policy,
            ProcessorErrorPolicy::Ignore
        ));
    }

    #[tokio::test]
    async fn test_dispatcher_not_connected() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();
        assert!(!dispatcher.is_connected());

        // Should fail when not connected
        let result = dispatcher.send(Message::Text("test".into())).await;
        assert!(matches!(result, Err(SendError::NotConnected)));
    }

    /// Detaching a dispatcher that was never attached is a harmless no-op.
    #[tokio::test]
    async fn test_detach_without_attach_is_noop() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();

        dispatcher.detach().await;

        assert!(!dispatcher.is_connected());
        let result = dispatcher.send(Message::Text("test".into())).await;
        assert!(matches!(result, Err(SendError::NotConnected)));
    }

    /// `subscriber_count` follows live broadcast receivers.
    #[test]
    fn test_subscriber_count_tracks_receivers() {
        let dispatcher = MessageDispatcher::<crate::connection::DefaultWsStream>::default();
        assert_eq!(dispatcher.subscriber_count(), 0);

        let first = dispatcher.subscribe();
        let second = dispatcher.subscribe();
        assert_eq!(dispatcher.subscriber_count(), 2);

        drop(first);
        assert_eq!(dispatcher.subscriber_count(), 1);
        drop(second);
        assert_eq!(dispatcher.subscriber_count(), 0);
    }
}