1use futures::stream::{SplitSink, SplitStream};
2use futures::{SinkExt, StreamExt};
3use log::{error, info};
4use std::time::Duration;
5use thiserror::Error;
6use tokio::net::TcpStream;
7use tokio::sync::broadcast::error::{RecvError as BroadcastRecvError, SendError};
8use tokio::sync::oneshot::error::RecvError as OneshotRecvError;
9use tokio::sync::{broadcast, mpsc, oneshot};
10use tokio::task::JoinError;
11use tokio_tungstenite::tungstenite::protocol::frame::Frame;
12use tokio_tungstenite::tungstenite::protocol::CloseFrame;
13pub use tokio_tungstenite::tungstenite::Message as TungsteniteMessage;
14use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
15pub use url::Url;
16
/// Base reconnect delay in milliseconds; `get_backoff` multiplies it by the squared attempt number.
const INITIAL_BACKOFF_MILLIS: u64 = 100;
/// Upper bound in milliseconds (five minutes) applied to the delay computed by `get_backoff`.
const MAX_BACKOFF_MILLIS: u64 = 5 * 60 * 1000;
/// Capacity used for both the outgoing mpsc channel and the incoming broadcast channel.
const CHANNEL_CAPACITY: usize = 32;
20
/// Sink half of a split client websocket stream; accepts tungstenite messages.
type WsWrite = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, TungsteniteMessage>;
/// Stream half of a split client websocket stream.
type WsRead = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
/// Shorthand for the tungstenite error type.
type WsError = tokio_tungstenite::tungstenite::Error;
24
/// Errors produced by the persistent websocket client.
///
/// Each variant wraps the error of one underlying library operation; matching `From`
/// implementations elsewhere in this file allow `?` to convert into this type.
#[derive(Debug, Error)]
pub enum Error {
    /// A tungstenite websocket error (connecting, reading or writing).
    #[error("tungstenite: {0}")]
    Tungstenite(WsError),
    /// Receiving from a broadcast channel failed.
    #[error("broadcast::recv: {0}")]
    BroadcastRecv(BroadcastRecvError),
    /// The oneshot used to report the first connection outcome was dropped before sending.
    #[error("oneshot::recv: {0}")]
    OneshotRecv(OneshotRecvError),
    /// A spawned task failed to join.
    #[error("join: {0}")]
    Join(JoinError),
    /// Publishing a message on the broadcast channel failed.
    #[error("broadcast::send: {0}")]
    Send(SendError<Message>),
}
38
/// Sender half used by callers to queue messages for writing to the websocket.
pub type WsMessageSender = mpsc::Sender<Message>;
/// Receiver half on which callers get messages read from the websocket and connection events.
pub type WsMessageReceiver = broadcast::Receiver<Message>;
41
/// Message exchanged with the persistent websocket.
///
/// Mirrors the tungstenite message variants and adds two connection lifecycle events that
/// are only ever delivered to receivers; they have no wire representation.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A ping message carrying the given payload.
    Ping(Vec<u8>),
    /// A pong message carrying the given payload.
    Pong(Vec<u8>),
    /// A close message with an optional close frame.
    Close(Option<CloseFrame<'static>>),
    /// A raw websocket frame.
    Frame(Frame),
    /// Published to receivers right after a websocket connection has been established.
    ConnectionOpened,
    /// Published to receivers after processing of a websocket connection has stopped.
    ConnectionClosed,
}
61
62pub async fn connect_persistent_websocket_async(
63 url: Url,
64) -> Result<(WsMessageSender, WsMessageReceiver), Error> {
65 let (msg_tx_out, msg_rx_out) = mpsc::channel::<Message>(CHANNEL_CAPACITY);
66 let (msg_tx_in, msg_rx_in) = broadcast::channel::<Message>(CHANNEL_CAPACITY);
67 let (first_conn_tx, first_conn_rx) = oneshot::channel();
68 tokio::spawn(setup_persistent_websocket(
69 url.into(),
70 msg_tx_in,
71 msg_rx_out,
72 first_conn_tx,
73 ));
74 if let Some(first_conn_error) = first_conn_rx.await? {
75 return Err(Error::Tungstenite(first_conn_error));
76 }
77 Ok((msg_tx_out, msg_rx_in))
78}
79
80async fn setup_persistent_websocket(
81 url: Url,
82 mut msg_tx_in: broadcast::Sender<Message>,
83 mut msg_rx_out: mpsc::Receiver<Message>,
84 first_conn_tx: oneshot::Sender<Option<WsError>>,
85) {
86 let mut connection_count = 0;
87 let mut first_conn_tx = Some(first_conn_tx);
88 loop {
89 info!("connecting to {url}");
90 let result = listen_for_persistent_ws_messages(
91 url.clone(),
92 &mut msg_tx_in,
93 &mut msg_rx_out,
94 &mut first_conn_tx,
95 )
96 .await;
97 if let Err(e) = result {
98 error!("error during ws connection: {e}");
99 }
100 info!("disconnected from {url}");
101 connection_count += 1;
102 tokio::time::sleep(get_backoff(connection_count)).await;
103 }
104}
105
106async fn listen_for_persistent_ws_messages(
107 url: Url,
108 msg_tx_in: &mut broadcast::Sender<Message>,
109 msg_rx_out: &mut mpsc::Receiver<Message>,
110 first_conn_tx: &mut Option<oneshot::Sender<Option<WsError>>>,
111) -> Result<(), Error> {
112 let connection_result = connect_async(url).await;
113 let (socket, _) = match connection_result {
114 Ok(c) => {
115 if let Some(first_conn_tx) = first_conn_tx.take() {
116 first_conn_tx.send(None).expect("first_conn_rx dropped")
117 }
118 c
119 }
120 Err(e) => {
121 if let Some(first_conn_tx) = first_conn_tx.take() {
122 first_conn_tx.send(Some(e)).expect("first_conn_rx dropped");
123 }
124 return Ok(());
125 }
126 };
127 let (mut ws_tx, mut ws_rx) = socket.split();
128 msg_tx_in.send(Message::ConnectionOpened)?;
129 let result = process_connection(&mut ws_tx, &mut ws_rx, msg_tx_in, msg_rx_out).await;
130 msg_tx_in.send(Message::ConnectionClosed)?;
131 result?;
132 Ok(())
133}
134
135async fn process_connection(
136 ws_tx: &mut WsWrite,
137 ws_rx: &mut WsRead,
138 msg_tx_in: &mut broadcast::Sender<Message>,
139 msg_rx_out: &mut mpsc::Receiver<Message>,
140) -> Result<(), Error> {
141 loop {
142 tokio::select! {
143 Some(incoming_msg) = ws_rx.next() => {
144 let should_close = handle_incoming_message(incoming_msg, ws_tx, msg_tx_in).await?;
145 if should_close {
146 info!("closing connection");
147 break;
148 }
149 },
150 Some(outgoing_msg) = msg_rx_out.recv() => {
151 let Some(outgoing_msg) = outgoing_msg.to_tungstenite() else {
152 continue;
153 };
154 ws_tx.send(outgoing_msg).await?;
155 }
156 }
157 }
158 Ok(())
159}
160
161async fn handle_incoming_message(
162 message: Result<TungsteniteMessage, WsError>,
163 ws_tx: &mut WsWrite,
164 msg_tx_in: &mut broadcast::Sender<Message>,
165) -> Result<bool, Error> {
166 let message = match message {
167 Ok(m) => m,
168 Err(e) => {
169 error!("connection error: {e}");
170 return Ok(true);
171 }
172 };
173 match message {
174 TungsteniteMessage::Ping(_) => {
175 #[cfg(feature = "pong")]
176 if let Err(e) = ws_tx.send(TungsteniteMessage::Pong(vec![])).await {
177 error!("error sending pong: {e}");
178 }
179 return Ok(false);
180 }
181 TungsteniteMessage::Close(frame) => {
182 info!("received socket close signal: {:#?}", frame);
183 return Ok(true);
184 }
185 _ => {}
186 }
187 msg_tx_in.send(message.into())?;
188 Ok(false)
189}
190
191fn get_backoff(attempt: u64) -> Duration {
192 let backoff = INITIAL_BACKOFF_MILLIS * attempt.pow(2);
193 Duration::from_millis(backoff.min(MAX_BACKOFF_MILLIS))
194}
195
196impl From<WsError> for Error {
197 fn from(e: WsError) -> Self {
198 Self::Tungstenite(e)
199 }
200}
201
202impl From<BroadcastRecvError> for Error {
203 fn from(e: BroadcastRecvError) -> Self {
204 Self::BroadcastRecv(e)
205 }
206}
207
208impl From<OneshotRecvError> for Error {
209 fn from(e: OneshotRecvError) -> Self {
210 Self::OneshotRecv(e)
211 }
212}
213
214impl From<JoinError> for Error {
215 fn from(e: JoinError) -> Self {
216 Self::Join(e)
217 }
218}
219
220impl From<SendError<Message>> for Error {
221 fn from(e: SendError<Message>) -> Self {
222 Self::Send(e)
223 }
224}
225
226impl Message {
227 fn to_tungstenite(self) -> Option<TungsteniteMessage> {
228 match self {
229 Message::Text(v) => Some(TungsteniteMessage::Text(v)),
230 Message::Binary(v) => Some(TungsteniteMessage::Binary(v)),
231 Message::Ping(v) => Some(TungsteniteMessage::Ping(v)),
232 Message::Pong(v) => Some(TungsteniteMessage::Pong(v)),
233 Message::Close(v) => Some(TungsteniteMessage::Close(v)),
234 Message::Frame(v) => Some(TungsteniteMessage::Frame(v)),
235 Message::ConnectionOpened => None,
236 Message::ConnectionClosed => None,
237 }
238 }
239}
240
241impl From<TungsteniteMessage> for Message {
242 fn from(message: TungsteniteMessage) -> Self {
243 match message {
244 TungsteniteMessage::Text(v) => Message::Text(v),
245 TungsteniteMessage::Binary(v) => Message::Binary(v),
246 TungsteniteMessage::Ping(v) => Message::Ping(v),
247 TungsteniteMessage::Pong(v) => Message::Pong(v),
248 TungsteniteMessage::Close(v) => Message::Close(v),
249 TungsteniteMessage::Frame(v) => Message::Frame(v),
250 }
251 }
252}