atomr_streams_io/ws.rs
1//! WebSocket frame `Source` with automatic reconnect.
2//!
3//! [`WsSource::connect`] opens a WebSocket connection via
4//! [`tokio_tungstenite::connect_async`], maps each inbound tungstenite
5//! [`Message`](tokio_tungstenite::tungstenite::Message) to a [`WsFrame`], and
6//! wraps the whole connect-and-stream cycle in
7//! [`RestartSource::with_backoff`](atomr_streams::RestartSource::with_backoff).
8//! When the connection drops (or the initial connect fails) the restart
9//! combinator re-runs the factory after a backoff governed by the supplied
10//! [`RestartSettings`], so the source
11//! transparently reconnects.
12//!
13//! Each connection attempt yields a [`Source`] of
14//! `Result<WsFrame, WsError>`: a failed connect produces a single
15//! `Err(WsError::Connect(_))` (after which restart backs off and retries),
16//! and a live connection produces `Ok(WsFrame::…)` per inbound message until
17//! the peer closes or errors.
18//!
19//! ## Rate limiting
20//!
//! As with [`crate::http_poll`], this source performs no rate limiting of its
//! own: apply the generic operators in `atomr_streams::rate` (e.g.
//! `token_bucket`) downstream of it so policy is uniform across sources.
24
25use atomr_streams::{RestartSettings, RestartSource, Source};
26use bytes::Bytes;
27use futures_util::StreamExt;
28
29/// A WebSocket frame, decoupled from the underlying tungstenite types.
30#[derive(Debug, Clone, PartialEq, Eq)]
31pub enum WsFrame {
32 /// A UTF-8 text frame.
33 Text(String),
34 /// A binary frame.
35 Binary(Bytes),
36 /// A ping control frame with its payload.
37 Ping(Bytes),
38 /// A pong control frame with its payload.
39 Pong(Bytes),
40 /// A close control frame (close reason discarded).
41 Close,
42}
43
44/// Errors raised by the WebSocket source.
45#[derive(Debug, thiserror::Error)]
46pub enum WsError {
47 /// Failed to establish (or re-establish) the connection.
48 #[error("websocket connect error: {0}")]
49 Connect(String),
50 /// A protocol / transport error on an established connection.
51 #[error("websocket protocol error: {0}")]
52 Protocol(String),
53}
54
/// Namespace for building WebSocket [`Source`]s; see [`WsSource::connect`].
pub struct WsSource;
57
58impl WsSource {
59 /// Connect to `url`, emitting `Ok(WsFrame)` per inbound message and
60 /// reconnecting with backoff per `restart` whenever the stream ends or a
61 /// connect attempt fails.
62 ///
63 /// A failed connect surfaces as a single `Err(WsError::Connect(_))` from
64 /// that attempt; the [`RestartSource`] then backs off and tries again
65 /// (subject to `restart.max_restarts`).
66 pub fn connect(url: url::Url, restart: RestartSettings) -> Source<Result<WsFrame, WsError>> {
67 RestartSource::with_backoff(restart, move || {
68 let url = url.clone();
69 connect_once(url)
70 })
71 }
72}
73
74/// Build a `Source` for a single connection attempt.
75///
76/// On connect failure the source is a single `Err(Connect(_))`; on success it
77/// streams mapped frames, terminating when the socket closes/errors so the
78/// outer [`RestartSource`] can reconnect.
79fn connect_once(url: url::Url) -> Source<Result<WsFrame, WsError>> {
80 // A bounded channel decouples the per-connection driver task from the
81 // consumer and bounds in-flight frames (backpressure).
82 let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<WsFrame, WsError>>();
83
84 tokio::spawn(async move {
85 match tokio_tungstenite::connect_async(url.as_str()).await {
86 Err(e) => {
87 let _ = tx.send(Err(WsError::Connect(e.to_string())));
88 // Stream ends -> RestartSource reconnects after backoff.
89 }
90 Ok((ws, _resp)) => {
91 let (_write, mut read) = ws.split();
92 while let Some(msg) = read.next().await {
93 match msg {
94 Ok(m) => {
95 if let Some(frame) = map_message(m) {
96 if tx.send(Ok(frame)).is_err() {
97 return; // consumer gone
98 }
99 }
100 }
101 Err(e) => {
102 let _ = tx.send(Err(WsError::Protocol(e.to_string())));
103 return; // end this connection's stream
104 }
105 }
106 }
107 }
108 }
109 });
110
111 Source::from_receiver(rx)
112}
113
114/// Map a tungstenite [`Message`] to a [`WsFrame`].
115///
116/// Returns `None` for `Frame(_)` raw frames (an internal tungstenite variant
117/// not produced by the high-level read stream).
118fn map_message(m: tokio_tungstenite::tungstenite::Message) -> Option<WsFrame> {
119 use tokio_tungstenite::tungstenite::Message as M;
120 match m {
121 M::Text(s) => Some(WsFrame::Text(s)),
122 M::Binary(b) => Some(WsFrame::Binary(Bytes::from(b))),
123 M::Ping(b) => Some(WsFrame::Ping(Bytes::from(b))),
124 M::Pong(b) => Some(WsFrame::Pong(Bytes::from(b))),
125 M::Close(_) => Some(WsFrame::Close),
126 M::Frame(_) => None,
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_streams::Sink;
    use std::time::Duration;
    use tokio::net::TcpListener;

    /// Restart policy with near-zero, jitter-free backoff so tests finish fast.
    fn fast_restart(max: usize) -> RestartSettings {
        let (min_backoff, max_backoff) = (Duration::from_millis(1), Duration::from_millis(5));
        RestartSettings {
            min_backoff,
            max_backoff,
            random_factor: 0.0,
            max_restarts: Some(max),
        }
    }

    #[tokio::test]
    async fn first_frame_is_text_against_local_server() {
        use futures_util::SinkExt;
        use tokio_tungstenite::tungstenite::Message;

        // One-shot local server: accept a single client, send one Text frame, close.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let accepted = listener.accept().await;
            if let Ok((tcp, _peer)) = accepted {
                let handshake = tokio_tungstenite::accept_async(tcp).await;
                if let Ok(mut server_ws) = handshake {
                    let _ = server_ws.send(Message::Text("hello".to_string())).await;
                    let _ = server_ws.close(None).await;
                }
            }
        });

        // A single connection's worth of frames is all this test needs.
        let endpoint = url::Url::parse(&format!("ws://{addr}/")).unwrap();
        let frames = WsSource::connect(endpoint, fast_restart(1));

        let head = Sink::first(frames).await.expect("expected one frame");
        match head {
            Ok(WsFrame::Text(text)) => assert_eq!(text, "hello"),
            other => panic!("expected Ok(Text(\"hello\")), got {other:?}"),
        }
    }

    #[tokio::test]
    async fn refused_port_surfaces_connect_err() {
        // No listener on port 1, so the connect attempt fails -> Err(Connect).
        let endpoint = url::Url::parse("ws://127.0.0.1:1/").unwrap();
        let frames = WsSource::connect(endpoint, fast_restart(1));

        let head = Sink::first(frames).await.expect("expected one emission");
        match head {
            Err(WsError::Connect(_)) => {}
            other => panic!("expected Err(Connect), got {other:?}"),
        }
    }
}