Skip to main content

atomr_streams_io/
ws.rs

1//! WebSocket frame `Source` with automatic reconnect.
2//!
3//! [`WsSource::connect`] opens a WebSocket connection via
4//! [`tokio_tungstenite::connect_async`], maps each inbound tungstenite
5//! [`Message`](tokio_tungstenite::tungstenite::Message) to a [`WsFrame`], and
6//! wraps the whole connect-and-stream cycle in
7//! [`RestartSource::with_backoff`](atomr_streams::RestartSource::with_backoff).
8//! When the connection drops (or the initial connect fails) the restart
9//! combinator re-runs the factory after a backoff governed by the supplied
10//! [`RestartSettings`], so the source
11//! transparently reconnects.
12//!
13//! Each connection attempt yields a [`Source`] of
14//! `Result<WsFrame, WsError>`: a failed connect produces a single
15//! `Err(WsError::Connect(_))` (after which restart backs off and retries),
16//! and a live connection produces `Ok(WsFrame::…)` per inbound message until
17//! the peer closes or errors.
18//!
19//! ## Rate limiting
20//!
//! As with [`crate::http_poll`], rate mediation is left to the shared
//! operators in `atomr_streams::rate` (e.g. `token_bucket`) applied to this
//! source, so policy is uniform across sources.
24
25use atomr_streams::{RestartSettings, RestartSource, Source};
26use bytes::Bytes;
27use futures_util::StreamExt;
28
29/// A WebSocket frame, decoupled from the underlying tungstenite types.
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub enum WsFrame {
32    /// A UTF-8 text frame.
33    Text(String),
34    /// A binary frame.
35    Binary(Bytes),
36    /// A ping control frame with its payload.
37    Ping(Bytes),
38    /// A pong control frame with its payload.
39    Pong(Bytes),
40    /// A close control frame (close reason discarded).
41    Close,
42}
43
44/// Errors raised by the WebSocket source.
45#[derive(Debug, thiserror::Error)]
46pub enum WsError {
47    /// Failed to establish (or re-establish) the connection.
48    #[error("websocket connect error: {0}")]
49    Connect(String),
50    /// A protocol / transport error on an established connection.
51    #[error("websocket protocol error: {0}")]
52    Protocol(String),
53}
54
55/// WebSocket source factory.
56pub struct WsSource;
57
58impl WsSource {
59    /// Connect to `url`, emitting `Ok(WsFrame)` per inbound message and
60    /// reconnecting with backoff per `restart` whenever the stream ends or a
61    /// connect attempt fails.
62    ///
63    /// A failed connect surfaces as a single `Err(WsError::Connect(_))` from
64    /// that attempt; the [`RestartSource`] then backs off and tries again
65    /// (subject to `restart.max_restarts`).
66    pub fn connect(url: url::Url, restart: RestartSettings) -> Source<Result<WsFrame, WsError>> {
67        RestartSource::with_backoff(restart, move || {
68            let url = url.clone();
69            connect_once(url)
70        })
71    }
72}
73
74/// Build a `Source` for a single connection attempt.
75///
76/// On connect failure the source is a single `Err(Connect(_))`; on success it
77/// streams mapped frames, terminating when the socket closes/errors so the
78/// outer [`RestartSource`] can reconnect.
79fn connect_once(url: url::Url) -> Source<Result<WsFrame, WsError>> {
80    // A bounded channel decouples the per-connection driver task from the
81    // consumer and bounds in-flight frames (backpressure).
82    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<WsFrame, WsError>>();
83
84    tokio::spawn(async move {
85        match tokio_tungstenite::connect_async(url.as_str()).await {
86            Err(e) => {
87                let _ = tx.send(Err(WsError::Connect(e.to_string())));
88                // Stream ends -> RestartSource reconnects after backoff.
89            }
90            Ok((ws, _resp)) => {
91                let (_write, mut read) = ws.split();
92                while let Some(msg) = read.next().await {
93                    match msg {
94                        Ok(m) => {
95                            if let Some(frame) = map_message(m) {
96                                if tx.send(Ok(frame)).is_err() {
97                                    return; // consumer gone
98                                }
99                            }
100                        }
101                        Err(e) => {
102                            let _ = tx.send(Err(WsError::Protocol(e.to_string())));
103                            return; // end this connection's stream
104                        }
105                    }
106                }
107            }
108        }
109    });
110
111    Source::from_receiver(rx)
112}
113
114/// Map a tungstenite [`Message`] to a [`WsFrame`].
115///
116/// Returns `None` for `Frame(_)` raw frames (an internal tungstenite variant
117/// not produced by the high-level read stream).
118fn map_message(m: tokio_tungstenite::tungstenite::Message) -> Option<WsFrame> {
119    use tokio_tungstenite::tungstenite::Message as M;
120    match m {
121        M::Text(s) => Some(WsFrame::Text(s)),
122        M::Binary(b) => Some(WsFrame::Binary(Bytes::from(b))),
123        M::Ping(b) => Some(WsFrame::Ping(Bytes::from(b))),
124        M::Pong(b) => Some(WsFrame::Pong(Bytes::from(b))),
125        M::Close(_) => Some(WsFrame::Close),
126        M::Frame(_) => None,
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_streams::Sink;
    use std::time::Duration;
    use tokio::net::TcpListener;

    /// Restart policy with millisecond backoff and no jitter, capped at `max`
    /// restarts, so tests never wait on a long backoff.
    fn fast_restart(max: usize) -> RestartSettings {
        let min_backoff = Duration::from_millis(1);
        let max_backoff = Duration::from_millis(5);
        RestartSettings {
            min_backoff,
            max_backoff,
            random_factor: 0.0,
            max_restarts: Some(max),
        }
    }

    /// Serve exactly one WebSocket client on `listener`: send a single
    /// `Text("hello")` frame, then close. Server-side failures are ignored;
    /// the tests only assert on what the client observes.
    fn spawn_one_shot_hello_server(listener: TcpListener) {
        tokio::spawn(async move {
            use futures_util::SinkExt;
            use tokio_tungstenite::tungstenite::Message;

            let tcp = match listener.accept().await {
                Ok((tcp, _peer)) => tcp,
                Err(_) => return,
            };
            let mut ws = match tokio_tungstenite::accept_async(tcp).await {
                Ok(ws) => ws,
                Err(_) => return,
            };
            let _ = ws.send(Message::Text("hello".to_string())).await;
            let _ = ws.close(None).await;
        });
    }

    #[tokio::test]
    async fn first_frame_is_text_against_local_server() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        spawn_one_shot_hello_server(listener);

        let target = url::Url::parse(&format!("ws://{addr}/")).unwrap();
        // A single connection's worth of frames is all this test consumes.
        let received = Sink::first(WsSource::connect(target, fast_restart(1)))
            .await
            .expect("expected one frame");

        if let Ok(WsFrame::Text(text)) = &received {
            assert_eq!(text.as_str(), "hello");
        } else {
            panic!("expected Ok(Text(\"hello\")), got {received:?}");
        }
    }

    #[tokio::test]
    async fn refused_port_surfaces_connect_err() {
        // No listener on port 1, so the connect is refused -> Err(Connect).
        let refused = url::Url::parse("ws://127.0.0.1:1/").unwrap();
        let emission = Sink::first(WsSource::connect(refused, fast_restart(1)))
            .await
            .expect("expected one emission");

        if !matches!(emission, Err(WsError::Connect(_))) {
            panic!("expected Err(Connect), got {emission:?}");
        }
    }
}